1mod migrate;
44
45use std::path::Path;
46use std::str::FromStr;
47
48use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
49use sqlx::SqlitePool;
50use tracing::{debug, info};
51
52use starpod_core::{StarpodError, Result};
53
54pub struct CoreDb {
65 pool: SqlitePool,
66}
67
68impl CoreDb {
69 pub async fn new(db_dir: &Path) -> Result<Self> {
75 std::fs::create_dir_all(db_dir)?;
76
77 let db_path = db_dir.join("core.db");
78 let opts = SqliteConnectOptions::from_str(
79 &format!("sqlite://{}?mode=rwc", db_path.display()),
80 )
81 .map_err(|e| StarpodError::Database(format!("Invalid DB path: {}", e)))?
82 .pragma("journal_mode", "WAL")
83 .pragma("foreign_keys", "ON");
84
85 let pool = SqlitePoolOptions::new()
86 .max_connections(10)
87 .connect_with(opts)
88 .await
89 .map_err(|e| StarpodError::Database(format!("Failed to open core db: {}", e)))?;
90
91 sqlx::migrate!("./migrations")
92 .run(&pool)
93 .await
94 .map_err(|e| StarpodError::Database(format!("Core migration failed: {}", e)))?;
95
96 debug!("core.db ready at {}", db_path.display());
97
98 if migrate::has_legacy_dbs(db_dir) {
100 info!("Legacy database files detected — migrating to core.db");
101 migrate::migrate_legacy_dbs(&pool, db_dir).await?;
102 }
103
104 Ok(Self { pool })
105 }
106
107 pub async fn in_memory() -> Result<Self> {
112 let opts = SqliteConnectOptions::from_str("sqlite::memory:")
113 .map_err(|e| StarpodError::Database(format!("Invalid memory DB: {}", e)))?
114 .pragma("foreign_keys", "ON");
115
116 let pool = SqlitePoolOptions::new()
117 .max_connections(1)
118 .connect_with(opts)
119 .await
120 .map_err(|e| StarpodError::Database(format!("Failed to open in-memory db: {}", e)))?;
121
122 sqlx::migrate!("./migrations")
123 .run(&pool)
124 .await
125 .map_err(|e| StarpodError::Database(format!("Core migration failed: {}", e)))?;
126
127 Ok(Self { pool })
128 }
129
130 pub fn pool(&self) -> &SqlitePool {
132 &self.pool
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs a query that yields a single integer (e.g. `COUNT(*)`) and returns it.
    async fn scalar(pool: &SqlitePool, sql: &'static str) -> i64 {
        let (value,): (i64,) = sqlx::query_as(sql).fetch_one(pool).await.unwrap();
        value
    }

    /// Executes a statement that must succeed.
    async fn exec(pool: &SqlitePool, sql: &'static str) {
        sqlx::query(sql).execute(pool).await.unwrap();
    }

    /// Executes a statement and reports whether the database refused it.
    async fn rejected(pool: &SqlitePool, sql: &'static str) -> bool {
        sqlx::query(sql).execute(pool).await.is_err()
    }

    #[tokio::test]
    async fn in_memory_creates_all_tables() {
        let db = CoreDb::in_memory().await.unwrap();

        // Each table must exist (the query would fail otherwise) and start empty.
        let empty_table_probes = [
            "SELECT COUNT(*) FROM users",
            "SELECT COUNT(*) FROM api_keys",
            "SELECT COUNT(*) FROM telegram_links",
            "SELECT COUNT(*) FROM auth_audit_log",
            "SELECT COUNT(*) FROM session_metadata",
            "SELECT COUNT(*) FROM session_messages",
            "SELECT COUNT(*) FROM usage_stats",
            "SELECT COUNT(*) FROM compaction_log",
            "SELECT COUNT(*) FROM cron_jobs",
            "SELECT COUNT(*) FROM cron_runs",
        ];
        for &probe in empty_table_probes.iter() {
            assert_eq!(scalar(db.pool(), probe).await, 0);
        }
    }

    #[tokio::test]
    async fn on_disk_creates_core_db() {
        let dir = tempfile::tempdir().unwrap();
        let db = CoreDb::new(dir.path()).await.unwrap();

        assert!(dir.path().join("core.db").exists());
        assert_eq!(scalar(db.pool(), "SELECT COUNT(*) FROM users").await, 0);
    }

    #[tokio::test]
    async fn on_disk_creates_parent_dirs() {
        let dir = tempfile::tempdir().unwrap();
        let target = dir.path().join("deep").join("nested").join("db");

        let handle = CoreDb::new(&target).await.unwrap();
        assert!(target.join("core.db").exists());
        drop(handle);
    }

    #[tokio::test]
    async fn reopen_is_idempotent() {
        let dir = tempfile::tempdir().unwrap();

        {
            let first = CoreDb::new(dir.path()).await.unwrap();
            exec(
                first.pool(),
                "INSERT INTO users (id, email, display_name, role, is_active, created_at, updated_at) \
                 VALUES ('u1', 'a@b.com', 'A', 'admin', 1, '2024-01-01', '2024-01-01')",
            )
            .await;
        } // first handle released before reopening

        let second = CoreDb::new(dir.path()).await.unwrap();
        let (email,): (String,) = sqlx::query_as("SELECT email FROM users WHERE id = 'u1'")
            .fetch_one(second.pool())
            .await
            .unwrap();
        assert_eq!(email, "a@b.com");
    }

    #[tokio::test]
    async fn fk_rejects_invalid_api_key_user() {
        let db = CoreDb::in_memory().await.unwrap();

        let refused = rejected(
            db.pool(),
            "INSERT INTO api_keys (id, user_id, prefix, key_hash, created_at) \
             VALUES ('k1', 'nonexistent', 'sp_', 'hash', '2024-01-01')",
        )
        .await;
        assert!(refused, "FK should reject api_key with invalid user_id");
    }

    #[tokio::test]
    async fn fk_rejects_invalid_telegram_link_user() {
        let db = CoreDb::in_memory().await.unwrap();

        let refused = rejected(
            db.pool(),
            "INSERT INTO telegram_links (telegram_id, user_id, username, linked_at) \
             VALUES (123, 'nonexistent', 'bob', '2024-01-01')",
        )
        .await;
        assert!(refused, "FK should reject telegram_link with invalid user_id");
    }

    #[tokio::test]
    async fn fk_rejects_invalid_session_message() {
        let db = CoreDb::in_memory().await.unwrap();

        let refused = rejected(
            db.pool(),
            "INSERT INTO session_messages (session_id, role, content, timestamp) \
             VALUES ('nonexistent', 'user', 'hello', '2024-01-01')",
        )
        .await;
        assert!(refused, "FK should reject message with invalid session_id");
    }

    #[tokio::test]
    async fn fk_rejects_invalid_cron_run_job() {
        let db = CoreDb::in_memory().await.unwrap();

        let refused = rejected(
            db.pool(),
            "INSERT INTO cron_runs (id, job_id, started_at, status) \
             VALUES ('r1', 'nonexistent', 1000, 'pending')",
        )
        .await;
        assert!(refused, "FK should reject cron_run with invalid job_id");
    }

    #[tokio::test]
    async fn cascade_delete_user_removes_api_keys() {
        let db = CoreDb::in_memory().await.unwrap();
        let p = db.pool();

        exec(
            p,
            "INSERT INTO users (id, email, role, is_active, created_at, updated_at) \
             VALUES ('u1', 'a@b.com', 'admin', 1, '2024-01-01', '2024-01-01')",
        )
        .await;
        exec(
            p,
            "INSERT INTO api_keys (id, user_id, prefix, key_hash, created_at) \
             VALUES ('k1', 'u1', 'sp_', 'hash1', '2024-01-01')",
        )
        .await;
        exec(
            p,
            "INSERT INTO api_keys (id, user_id, prefix, key_hash, created_at) \
             VALUES ('k2', 'u1', 'sp_', 'hash2', '2024-01-01')",
        )
        .await;

        exec(p, "DELETE FROM users WHERE id = 'u1'").await;

        assert_eq!(scalar(p, "SELECT COUNT(*) FROM api_keys").await, 0);
    }

    #[tokio::test]
    async fn cascade_delete_user_removes_telegram_links() {
        let db = CoreDb::in_memory().await.unwrap();
        let p = db.pool();

        exec(
            p,
            "INSERT INTO users (id, role, is_active, created_at, updated_at) \
             VALUES ('u1', 'admin', 1, '2024-01-01', '2024-01-01')",
        )
        .await;
        exec(
            p,
            "INSERT INTO telegram_links (telegram_id, user_id, username, linked_at) \
             VALUES (999, 'u1', 'bob', '2024-01-01')",
        )
        .await;

        exec(p, "DELETE FROM users WHERE id = 'u1'").await;

        assert_eq!(scalar(p, "SELECT COUNT(*) FROM telegram_links").await, 0);
    }

    #[tokio::test]
    async fn cascade_delete_session_removes_messages_and_compaction() {
        let db = CoreDb::in_memory().await.unwrap();
        let p = db.pool();

        exec(
            p,
            "INSERT INTO session_metadata (id, created_at, last_message_at) \
             VALUES ('s1', '2024-01-01', '2024-01-01')",
        )
        .await;
        exec(
            p,
            "INSERT INTO session_messages (session_id, role, content, timestamp) \
             VALUES ('s1', 'user', 'hi', '2024-01-01')",
        )
        .await;
        exec(
            p,
            "INSERT INTO compaction_log (session_id, timestamp, trigger, pre_tokens, summary) \
             VALUES ('s1', '2024-01-01', 'auto', 1000, 'summary')",
        )
        .await;

        exec(p, "DELETE FROM session_metadata WHERE id = 's1'").await;

        assert_eq!(scalar(p, "SELECT COUNT(*) FROM session_messages").await, 0);
        assert_eq!(scalar(p, "SELECT COUNT(*) FROM compaction_log").await, 0);
    }

    #[tokio::test]
    async fn cascade_delete_cron_job_removes_runs() {
        let db = CoreDb::in_memory().await.unwrap();
        let p = db.pool();

        exec(
            p,
            "INSERT INTO cron_jobs (id, name, prompt, schedule_type, schedule_value, created_at) \
             VALUES ('j1', 'test', 'do stuff', 'interval', '60000', 1000)",
        )
        .await;
        exec(
            p,
            "INSERT INTO cron_runs (id, job_id, started_at, status) \
             VALUES ('r1', 'j1', 2000, 'success')",
        )
        .await;

        exec(p, "DELETE FROM cron_jobs WHERE id = 'j1'").await;

        assert_eq!(scalar(p, "SELECT COUNT(*) FROM cron_runs").await, 0);
    }

    #[tokio::test]
    async fn cross_domain_join_sessions_with_usage_by_user() {
        let db = CoreDb::in_memory().await.unwrap();
        let p = db.pool();

        exec(
            p,
            "INSERT INTO users (id, email, role, is_active, created_at, updated_at) \
             VALUES ('u1', 'alice@test.com', 'admin', 1, '2024-01-01', '2024-01-01')",
        )
        .await;
        exec(
            p,
            "INSERT INTO session_metadata (id, created_at, last_message_at, user_id) \
             VALUES ('s1', '2024-01-01', '2024-01-01', 'u1')",
        )
        .await;
        exec(
            p,
            "INSERT INTO usage_stats (session_id, turn, input_tokens, output_tokens, cost_usd, timestamp, user_id) \
             VALUES ('s1', 1, 100, 200, 0.01, '2024-01-01', 'u1')",
        )
        .await;

        let (email, total_cost): (String, f64) = sqlx::query_as(
            "SELECT u.email, SUM(us.cost_usd) as total_cost \
             FROM users u \
             JOIN usage_stats us ON us.user_id = u.id \
             GROUP BY u.id",
        )
        .fetch_one(p)
        .await
        .unwrap();

        assert_eq!(email, "alice@test.com");
        assert!((total_cost - 0.01).abs() < 0.001);
    }

    #[tokio::test]
    async fn pool_clone_shares_state() {
        let db = CoreDb::in_memory().await.unwrap();

        exec(
            db.pool(),
            "INSERT INTO users (id, role, is_active, created_at, updated_at) \
             VALUES ('u1', 'admin', 1, '2024-01-01', '2024-01-01')",
        )
        .await;

        // A cloned pool must observe rows written through the original handle.
        let cloned = db.pool().clone();
        assert_eq!(scalar(&cloned, "SELECT COUNT(*) FROM users").await, 1);
    }
}