Skip to main content

starpod_db/
lib.rs

1//! Unified SQLite database for Starpod's transactional data.
2//!
3//! All transactional data (sessions, cron scheduling, authentication) lives in
4//! a single `core.db` file. A shared connection pool (WAL mode, foreign keys
5//! enabled) serves all three domains.
6//!
7//! # Architecture
8//!
9//! ```text
10//! ┌──────────┐
11//! │  CoreDb   │  owns SqlitePool (max 2 conns, WAL, FK ON)
12//! └────┬─────┘
13//!      │ pool.clone()
14//!      ├──────────────► SessionManager::from_pool(pool)
15//!      ├──────────────► CronStore::from_pool(pool)
16//!      └──────────────► AuthStore::from_pool(pool)
17//! ```
18//!
19//! # Databases kept separate
20//!
21//! - **memory.db** — FTS5 + vector blobs, bulk reindex I/O, different access pattern
22//! - **vault.db** — AES-256-GCM encrypted, optional (needs `.vault_key`), isolated security boundary
23//!
24//! # Usage
25//!
26//! ```no_run
27//! # async fn example() -> starpod_core::Result<()> {
28//! use starpod_db::CoreDb;
29//!
30//! let db = CoreDb::new(std::path::Path::new(".starpod/db")).await?;
31//! // Pass db.pool().clone() to SessionManager, CronStore, AuthStore
32//! # Ok(())
33//! # }
34//! ```
35
36use std::path::Path;
37use std::str::FromStr;
38
39use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
40use sqlx::SqlitePool;
41use tracing::{debug, info};
42
43use starpod_core::{Result, StarpodError};
44
/// Unified database for sessions, cron, and auth.
///
/// Owns a single `SqlitePool` backed by `core.db`. Individual stores
/// (`SessionManager`, `CronStore`, `AuthStore`) receive a clone of the
/// pool via `from_pool()` instead of opening their own connections.
/// (`SqlitePool` is internally reference-counted, so cloning is cheap
/// and all clones share the same connections.)
///
/// The pool is configured with:
/// - **WAL journal mode** — concurrent readers don't block writers
/// - **Foreign keys ON** — referential integrity across all tables
/// - **2 max connections** — one writer + one reader; SQLite serialises
///   writes anyway, so more connections just waste memory (~2 MB page
///   cache each) and cause lock contention on small VMs
/// - **`busy_timeout = 5000`** — wait up to 5 s for a lock instead of
///   returning SQLITE_BUSY immediately
/// - **`synchronous = NORMAL`** — safe with WAL, avoids fsync per commit
/// - **`cache_size = -2000`** — 2 MB page cache per connection (default)
pub struct CoreDb {
    // Private: external code reaches the pool only through `pool()`.
    pool: SqlitePool,
}
64
65impl CoreDb {
66    /// Open (or create) `core.db` inside `db_dir`.
67    ///
68    /// Runs all migrations from `./migrations`. If migrations fail due to a
69    /// checksum mismatch or a removed migration (common during development
70    /// when migration files are edited in-place), the database is deleted
71    /// and recreated from scratch.
72    pub async fn new(db_dir: &Path) -> Result<Self> {
73        std::fs::create_dir_all(db_dir)?;
74
75        let db_path = db_dir.join("core.db");
76
77        // Try to open and migrate; on schema mismatch, recreate from scratch.
78        match Self::open_and_migrate(&db_path).await {
79            Ok(pool) => {
80                debug!("core.db ready at {}", db_path.display());
81                Ok(Self { pool })
82            }
83            Err(e) => {
84                let msg = e.to_string();
85                let is_schema_mismatch = msg.contains("previously applied but is missing")
86                    || msg.contains("checksum mismatch");
87
88                if !is_schema_mismatch {
89                    return Err(e);
90                }
91
92                info!("Migration schema changed — recreating core.db");
93                // Remove db + WAL/SHM files
94                let db_str = db_path.display().to_string();
95                let _ = std::fs::remove_file(&db_path);
96                let _ = std::fs::remove_file(format!("{db_str}-wal"));
97                let _ = std::fs::remove_file(format!("{db_str}-shm"));
98
99                let pool = Self::open_and_migrate(&db_path).await?;
100                debug!("core.db recreated at {}", db_path.display());
101
102                Ok(Self { pool })
103            }
104        }
105    }
106
107    /// Open (or create) the database file and run migrations.
108    async fn open_and_migrate(db_path: &Path) -> Result<SqlitePool> {
109        let opts =
110            SqliteConnectOptions::from_str(&format!("sqlite://{}?mode=rwc", db_path.display()))
111                .map_err(|e| StarpodError::Database(format!("Invalid DB path: {}", e)))?
112                .pragma("journal_mode", "WAL")
113                .pragma("foreign_keys", "ON")
114                .pragma("busy_timeout", "5000")
115                .pragma("synchronous", "NORMAL");
116
117        let pool = SqlitePoolOptions::new()
118            .max_connections(2)
119            .connect_with(opts)
120            .await
121            .map_err(|e| StarpodError::Database(format!("Failed to open core db: {}", e)))?;
122
123        sqlx::migrate!("./migrations")
124            .run(&pool)
125            .await
126            .map_err(|e| StarpodError::Database(format!("Core migration failed: {}", e)))?;
127
128        Ok(pool)
129    }
130
131    /// Create an in-memory `CoreDb` for testing.
132    ///
133    /// Runs all migrations on a shared in-memory database. Each call
134    /// returns a fresh, empty database.
135    pub async fn in_memory() -> Result<Self> {
136        let opts = SqliteConnectOptions::from_str("sqlite::memory:")
137            .map_err(|e| StarpodError::Database(format!("Invalid memory DB: {}", e)))?
138            .pragma("foreign_keys", "ON");
139
140        let pool = SqlitePoolOptions::new()
141            .max_connections(1)
142            .connect_with(opts)
143            .await
144            .map_err(|e| StarpodError::Database(format!("Failed to open in-memory db: {}", e)))?;
145
146        sqlx::migrate!("./migrations")
147            .run(&pool)
148            .await
149            .map_err(|e| StarpodError::Database(format!("Core migration failed: {}", e)))?;
150
151        Ok(Self { pool })
152    }
153
154    /// Get a reference to the shared connection pool.
155    pub fn pool(&self) -> &SqlitePool {
156        &self.pool
157    }
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    // ── Basic lifecycle ─────────────────────────────────────────────

    // Verifies that the migrations create every table across all three
    // domains (auth, sessions, cron) and that each starts empty.
    #[tokio::test]
    async fn in_memory_creates_all_tables() {
        let db = CoreDb::in_memory().await.unwrap();
        let pool = db.pool();

        // Auth tables
        let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users")
            .fetch_one(pool)
            .await
            .unwrap();
        assert_eq!(row.0, 0);

        let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM api_keys")
            .fetch_one(pool)
            .await
            .unwrap();
        assert_eq!(row.0, 0);

        let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM telegram_links")
            .fetch_one(pool)
            .await
            .unwrap();
        assert_eq!(row.0, 0);

        let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM auth_audit_log")
            .fetch_one(pool)
            .await
            .unwrap();
        assert_eq!(row.0, 0);

        // Session tables
        let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM session_metadata")
            .fetch_one(pool)
            .await
            .unwrap();
        assert_eq!(row.0, 0);

        let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM session_messages")
            .fetch_one(pool)
            .await
            .unwrap();
        assert_eq!(row.0, 0);

        let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM usage_stats")
            .fetch_one(pool)
            .await
            .unwrap();
        assert_eq!(row.0, 0);

        let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM compaction_log")
            .fetch_one(pool)
            .await
            .unwrap();
        assert_eq!(row.0, 0);

        // Cron tables
        let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM cron_jobs")
            .fetch_one(pool)
            .await
            .unwrap();
        assert_eq!(row.0, 0);

        let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM cron_runs")
            .fetch_one(pool)
            .await
            .unwrap();
        assert_eq!(row.0, 0);
    }

    // Opening against a real directory must produce a `core.db` file
    // with a queryable schema.
    #[tokio::test]
    async fn on_disk_creates_core_db() {
        let tmp = tempfile::tempdir().unwrap();
        let db = CoreDb::new(tmp.path()).await.unwrap();

        assert!(tmp.path().join("core.db").exists());

        let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users")
            .fetch_one(db.pool())
            .await
            .unwrap();
        assert_eq!(row.0, 0);
    }

    // `CoreDb::new` calls `create_dir_all`, so missing ancestors of the
    // db directory must be created rather than failing.
    #[tokio::test]
    async fn on_disk_creates_parent_dirs() {
        let tmp = tempfile::tempdir().unwrap();
        let nested = tmp.path().join("deep").join("nested").join("db");
        let db = CoreDb::new(&nested).await.unwrap();

        assert!(nested.join("core.db").exists());
        drop(db);
    }

    // A second open of the same directory must reuse the existing file
    // (and its data) instead of wiping and recreating it.
    #[tokio::test]
    async fn reopen_is_idempotent() {
        let tmp = tempfile::tempdir().unwrap();

        // First open — creates the DB
        let db1 = CoreDb::new(tmp.path()).await.unwrap();
        sqlx::query(
            "INSERT INTO users (id, email, display_name, role, is_active, created_at, updated_at) \
             VALUES ('u1', 'a@b.com', 'A', 'admin', 1, '2024-01-01', '2024-01-01')",
        )
        .execute(db1.pool())
        .await
        .unwrap();
        drop(db1);

        // Second open — should find existing data, not recreate
        let db2 = CoreDb::new(tmp.path()).await.unwrap();
        let row: (String,) = sqlx::query_as("SELECT email FROM users WHERE id = 'u1'")
            .fetch_one(db2.pool())
            .await
            .unwrap();
        assert_eq!(row.0, "a@b.com");
    }

    // ── Foreign key enforcement ─────────────────────────────────────
    // These tests prove the `foreign_keys` pragma is actually ON for
    // pool connections — SQLite defaults it to OFF per connection.

    #[tokio::test]
    async fn fk_rejects_invalid_api_key_user() {
        let db = CoreDb::in_memory().await.unwrap();

        let result = sqlx::query(
            "INSERT INTO api_keys (id, user_id, prefix, key_hash, created_at) \
             VALUES ('k1', 'nonexistent', 'sp_', 'hash', '2024-01-01')",
        )
        .execute(db.pool())
        .await;

        assert!(
            result.is_err(),
            "FK should reject api_key with invalid user_id"
        );
    }

    #[tokio::test]
    async fn fk_rejects_invalid_telegram_link_user() {
        let db = CoreDb::in_memory().await.unwrap();

        let result = sqlx::query(
            "INSERT INTO telegram_links (telegram_id, user_id, username, linked_at) \
             VALUES (123, 'nonexistent', 'bob', '2024-01-01')",
        )
        .execute(db.pool())
        .await;

        assert!(
            result.is_err(),
            "FK should reject telegram_link with invalid user_id"
        );
    }

    #[tokio::test]
    async fn fk_rejects_invalid_session_message() {
        let db = CoreDb::in_memory().await.unwrap();

        let result = sqlx::query(
            "INSERT INTO session_messages (session_id, role, content, timestamp) \
             VALUES ('nonexistent', 'user', 'hello', '2024-01-01')",
        )
        .execute(db.pool())
        .await;

        assert!(
            result.is_err(),
            "FK should reject message with invalid session_id"
        );
    }

    #[tokio::test]
    async fn fk_rejects_invalid_cron_run_job() {
        let db = CoreDb::in_memory().await.unwrap();

        let result = sqlx::query(
            "INSERT INTO cron_runs (id, job_id, started_at, status) \
             VALUES ('r1', 'nonexistent', 1000, 'pending')",
        )
        .execute(db.pool())
        .await;

        assert!(
            result.is_err(),
            "FK should reject cron_run with invalid job_id"
        );
    }

    // ── CASCADE deletes ─────────────────────────────────────────────
    // Deleting a parent row must remove its children via ON DELETE
    // CASCADE, which also only fires when foreign_keys is ON.

    #[tokio::test]
    async fn cascade_delete_user_removes_api_keys() {
        let db = CoreDb::in_memory().await.unwrap();
        let pool = db.pool();

        sqlx::query(
            "INSERT INTO users (id, email, role, is_active, created_at, updated_at) \
             VALUES ('u1', 'a@b.com', 'admin', 1, '2024-01-01', '2024-01-01')",
        )
        .execute(pool)
        .await
        .unwrap();

        sqlx::query(
            "INSERT INTO api_keys (id, user_id, prefix, key_hash, created_at) \
             VALUES ('k1', 'u1', 'sp_', 'hash1', '2024-01-01')",
        )
        .execute(pool)
        .await
        .unwrap();

        sqlx::query(
            "INSERT INTO api_keys (id, user_id, prefix, key_hash, created_at) \
             VALUES ('k2', 'u1', 'sp_', 'hash2', '2024-01-01')",
        )
        .execute(pool)
        .await
        .unwrap();

        // Delete user
        sqlx::query("DELETE FROM users WHERE id = 'u1'")
            .execute(pool)
            .await
            .unwrap();

        // API keys should be gone
        let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM api_keys")
            .fetch_one(pool)
            .await
            .unwrap();
        assert_eq!(row.0, 0);
    }

    #[tokio::test]
    async fn cascade_delete_user_removes_telegram_links() {
        let db = CoreDb::in_memory().await.unwrap();
        let pool = db.pool();

        sqlx::query(
            "INSERT INTO users (id, role, is_active, created_at, updated_at) \
             VALUES ('u1', 'admin', 1, '2024-01-01', '2024-01-01')",
        )
        .execute(pool)
        .await
        .unwrap();

        sqlx::query(
            "INSERT INTO telegram_links (telegram_id, user_id, username, linked_at) \
             VALUES (999, 'u1', 'bob', '2024-01-01')",
        )
        .execute(pool)
        .await
        .unwrap();

        sqlx::query("DELETE FROM users WHERE id = 'u1'")
            .execute(pool)
            .await
            .unwrap();

        let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM telegram_links")
            .fetch_one(pool)
            .await
            .unwrap();
        assert_eq!(row.0, 0);
    }

    #[tokio::test]
    async fn cascade_delete_session_removes_messages_and_compaction() {
        let db = CoreDb::in_memory().await.unwrap();
        let pool = db.pool();

        sqlx::query(
            "INSERT INTO session_metadata (id, created_at, last_message_at) \
             VALUES ('s1', '2024-01-01', '2024-01-01')",
        )
        .execute(pool)
        .await
        .unwrap();

        sqlx::query(
            "INSERT INTO session_messages (session_id, role, content, timestamp) \
             VALUES ('s1', 'user', 'hi', '2024-01-01')",
        )
        .execute(pool)
        .await
        .unwrap();

        sqlx::query(
            "INSERT INTO compaction_log (session_id, timestamp, trigger, pre_tokens, summary) \
             VALUES ('s1', '2024-01-01', 'auto', 1000, 'summary')",
        )
        .execute(pool)
        .await
        .unwrap();

        sqlx::query("DELETE FROM session_metadata WHERE id = 's1'")
            .execute(pool)
            .await
            .unwrap();

        let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM session_messages")
            .fetch_one(pool)
            .await
            .unwrap();
        assert_eq!(row.0, 0);

        let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM compaction_log")
            .fetch_one(pool)
            .await
            .unwrap();
        assert_eq!(row.0, 0);
    }

    #[tokio::test]
    async fn cascade_delete_cron_job_removes_runs() {
        let db = CoreDb::in_memory().await.unwrap();
        let pool = db.pool();

        sqlx::query(
            "INSERT INTO cron_jobs (id, name, prompt, schedule_type, schedule_value, created_at) \
             VALUES ('j1', 'test', 'do stuff', 'interval', '60000', 1000)",
        )
        .execute(pool)
        .await
        .unwrap();

        sqlx::query(
            "INSERT INTO cron_runs (id, job_id, started_at, status) \
             VALUES ('r1', 'j1', 2000, 'success')",
        )
        .execute(pool)
        .await
        .unwrap();

        sqlx::query("DELETE FROM cron_jobs WHERE id = 'j1'")
            .execute(pool)
            .await
            .unwrap();

        let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM cron_runs")
            .fetch_one(pool)
            .await
            .unwrap();
        assert_eq!(row.0, 0);
    }

    // ── Cross-domain queries (the whole point of consolidation) ─────

    #[tokio::test]
    async fn cross_domain_join_sessions_with_usage_by_user() {
        let db = CoreDb::in_memory().await.unwrap();
        let pool = db.pool();

        // Create a user
        sqlx::query(
            "INSERT INTO users (id, email, role, is_active, created_at, updated_at) \
             VALUES ('u1', 'alice@test.com', 'admin', 1, '2024-01-01', '2024-01-01')",
        )
        .execute(pool)
        .await
        .unwrap();

        // Create sessions for this user
        sqlx::query(
            "INSERT INTO session_metadata (id, created_at, last_message_at, user_id) \
             VALUES ('s1', '2024-01-01', '2024-01-01', 'u1')",
        )
        .execute(pool)
        .await
        .unwrap();

        // Record usage
        sqlx::query(
            "INSERT INTO usage_stats (session_id, turn, input_tokens, output_tokens, cost_usd, timestamp, user_id) \
             VALUES ('s1', 1, 100, 200, 0.01, '2024-01-01', 'u1')"
        ).execute(pool).await.unwrap();

        // Cross-domain query: total cost per user (joins users + usage_stats)
        let row: (String, f64) = sqlx::query_as(
            "SELECT u.email, SUM(us.cost_usd) as total_cost \
             FROM users u \
             JOIN usage_stats us ON us.user_id = u.id \
             GROUP BY u.id",
        )
        .fetch_one(pool)
        .await
        .unwrap();

        assert_eq!(row.0, "alice@test.com");
        // Tolerance comparison: cost_usd is a REAL, exact equality is unsafe
        assert!((row.1 - 0.01).abs() < 0.001);
    }

    // Simulates the store pattern: a cloned pool must observe writes
    // made through the original handle.
    #[tokio::test]
    async fn pool_clone_shares_state() {
        let db = CoreDb::in_memory().await.unwrap();

        // Insert on original pool
        sqlx::query(
            "INSERT INTO users (id, role, is_active, created_at, updated_at) \
             VALUES ('u1', 'admin', 1, '2024-01-01', '2024-01-01')",
        )
        .execute(db.pool())
        .await
        .unwrap();

        // Read from cloned pool (simulates what stores do)
        let pool2 = db.pool().clone();
        let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM users")
            .fetch_one(&pool2)
            .await
            .unwrap();
        assert_eq!(row.0, 1);
    }
}