1use std::path::{Path, PathBuf};
7use std::time::Instant;
8
9use chrono::Utc;
10use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
11use sqlx::Row;
12
13use super::DbPool;
14use crate::error::StorageError;
15
/// Summary of a backup file produced by `VACUUM INTO`.
#[derive(Clone, Debug)]
pub struct BackupResult {
    /// Where the backup file was written.
    pub path: PathBuf,
    /// On-disk size of the backup file in bytes.
    pub size_bytes: u64,
    /// Elapsed wall-clock time for writing the backup, in milliseconds.
    pub duration_ms: u64,
}
26
/// A backup file discovered on disk by the listing helpers.
#[derive(Clone, Debug)]
pub struct BackupInfo {
    /// Full path of the backup file.
    pub path: PathBuf,
    /// File size in bytes (reported as 0 when it could not be read).
    pub size_bytes: u64,
    /// Timestamp portion of the filename, if it could be extracted.
    pub timestamp: Option<String>,
}
37
/// Outcome of inspecting a backup file.
#[derive(Clone, Debug)]
pub struct ValidationResult {
    /// `true` when the backup passed every check.
    pub valid: bool,
    /// Names of the user tables found in the backup.
    pub tables: Vec<String>,
    /// Human-readable notes explaining the verdict.
    pub messages: Vec<String>,
}
48
49pub async fn create_backup(pool: &DbPool, backup_dir: &Path) -> Result<BackupResult, StorageError> {
54 create_backup_with_prefix(pool, backup_dir, "tuitbot").await
55}
56
57async fn create_backup_with_prefix(
59 pool: &DbPool,
60 backup_dir: &Path,
61 prefix: &str,
62) -> Result<BackupResult, StorageError> {
63 std::fs::create_dir_all(backup_dir).map_err(|e| StorageError::Connection {
64 source: sqlx::Error::Configuration(
65 format!(
66 "failed to create backup directory {}: {e}",
67 backup_dir.display()
68 )
69 .into(),
70 ),
71 })?;
72
73 let timestamp = Utc::now().format("%Y%m%d_%H%M%S");
74 let filename = format!("{prefix}_{timestamp}.db");
75 let backup_path = backup_dir.join(&filename);
76
77 let start = Instant::now();
78
79 let path_str = backup_path
80 .to_str()
81 .ok_or_else(|| StorageError::Connection {
82 source: sqlx::Error::Configuration("backup path is not valid UTF-8".into()),
83 })?
84 .to_string();
85
86 let query = format!("VACUUM INTO '{}'", path_str.replace('\'', "''"));
88 sqlx::query(&query)
89 .execute(pool)
90 .await
91 .map_err(|e| StorageError::Query { source: e })?;
92
93 let duration_ms = start.elapsed().as_millis() as u64;
94
95 let metadata = std::fs::metadata(&backup_path).map_err(|e| StorageError::Connection {
96 source: sqlx::Error::Configuration(format!("failed to stat backup file: {e}").into()),
97 })?;
98
99 Ok(BackupResult {
100 path: backup_path,
101 size_bytes: metadata.len(),
102 duration_ms,
103 })
104}
105
106pub async fn validate_backup(backup_path: &Path) -> Result<ValidationResult, StorageError> {
108 if !backup_path.exists() {
109 return Ok(ValidationResult {
110 valid: false,
111 tables: vec![],
112 messages: vec![format!("File not found: {}", backup_path.display())],
113 });
114 }
115
116 let path_str = backup_path.to_string_lossy();
117 let options = SqliteConnectOptions::new()
118 .filename(&*path_str)
119 .read_only(true);
120
121 let pool = SqlitePoolOptions::new()
122 .max_connections(1)
123 .connect_with(options)
124 .await
125 .map_err(|e| StorageError::Connection { source: e })?;
126
127 let mut messages = Vec::new();
128
129 let rows = sqlx::query(
131 "SELECT name FROM sqlite_master WHERE type='table' \
132 AND name NOT LIKE 'sqlite_%' AND name != '_sqlx_migrations' \
133 ORDER BY name",
134 )
135 .fetch_all(&pool)
136 .await
137 .map_err(|e| StorageError::Query { source: e })?;
138
139 let tables: Vec<String> = rows.iter().map(|r| r.get("name")).collect();
140
141 let expected = [
143 "action_log",
144 "discovered_tweets",
145 "replies_sent",
146 "rate_limits",
147 ];
148 let mut missing = Vec::new();
149 for table in &expected {
150 if !tables.iter().any(|t| t == table) {
151 missing.push(*table);
152 }
153 }
154
155 let valid = missing.is_empty() && !tables.is_empty();
156
157 if valid {
158 messages.push(format!("Valid backup with {} tables", tables.len()));
159 } else if tables.is_empty() {
160 messages.push("No tables found in backup".to_string());
161 } else {
162 messages.push(format!("Missing expected tables: {}", missing.join(", ")));
163 }
164
165 let integrity: String = sqlx::query_scalar("PRAGMA integrity_check")
167 .fetch_one(&pool)
168 .await
169 .unwrap_or_else(|_| "error".to_string());
170
171 if integrity != "ok" {
172 messages.push(format!("Integrity check failed: {integrity}"));
173 return Ok(ValidationResult {
174 valid: false,
175 tables,
176 messages,
177 });
178 }
179
180 pool.close().await;
181
182 Ok(ValidationResult {
183 valid,
184 tables,
185 messages,
186 })
187}
188
189pub async fn restore_from_backup(
195 backup_path: &Path,
196 target_path: &Path,
197) -> Result<(), StorageError> {
198 let validation = validate_backup(backup_path).await?;
200 if !validation.valid {
201 return Err(StorageError::Connection {
202 source: sqlx::Error::Configuration(
203 format!(
204 "Backup validation failed: {}",
205 validation.messages.join("; ")
206 )
207 .into(),
208 ),
209 });
210 }
211
212 if target_path.exists() {
214 let parent = target_path.parent().unwrap_or_else(|| Path::new("."));
215 let safety_name = format!("pre_restore_{}.db", Utc::now().format("%Y%m%d_%H%M%S"));
216 let safety_path = parent.join(safety_name);
217 std::fs::copy(target_path, &safety_path).map_err(|e| StorageError::Connection {
218 source: sqlx::Error::Configuration(
219 format!("Failed to create safety backup: {e}").into(),
220 ),
221 })?;
222 tracing::info!(
223 path = %safety_path.display(),
224 "Created safety backup of current database"
225 );
226 }
227
228 let parent = target_path.parent().unwrap_or_else(|| Path::new("."));
230 let temp_path = parent.join(format!(
231 ".tuitbot_restore_{}.tmp",
232 Utc::now().timestamp_millis()
233 ));
234
235 std::fs::copy(backup_path, &temp_path).map_err(|e| StorageError::Connection {
236 source: sqlx::Error::Configuration(format!("Failed to copy backup: {e}").into()),
237 })?;
238
239 std::fs::rename(&temp_path, target_path).map_err(|e| StorageError::Connection {
240 source: sqlx::Error::Configuration(format!("Failed to rename temp to target: {e}").into()),
241 })?;
242
243 let wal_path = target_path.with_extension("db-wal");
245 let shm_path = target_path.with_extension("db-shm");
246 let _ = std::fs::remove_file(wal_path);
247 let _ = std::fs::remove_file(shm_path);
248
249 Ok(())
250}
251
252pub fn list_backups(backup_dir: &Path) -> Vec<BackupInfo> {
254 let mut backups = Vec::new();
255
256 let entries = match std::fs::read_dir(backup_dir) {
257 Ok(e) => e,
258 Err(_) => return backups,
259 };
260
261 for entry in entries.flatten() {
262 let path = entry.path();
263 let name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");
264
265 if !name.starts_with("tuitbot_") || !name.ends_with(".db") {
266 continue;
267 }
268
269 let size_bytes = std::fs::metadata(&path).map(|m| m.len()).unwrap_or(0);
270
271 let timestamp = name
273 .strip_prefix("tuitbot_")
274 .and_then(|s| s.strip_suffix(".db"))
275 .map(|s| s.to_string());
276
277 backups.push(BackupInfo {
278 path,
279 size_bytes,
280 timestamp,
281 });
282 }
283
284 backups.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
286 backups
287}
288
289pub fn prune_backups(backup_dir: &Path, keep: usize) -> Result<u32, StorageError> {
293 let backups = list_backups(backup_dir);
294 let mut deleted = 0u32;
295
296 if backups.len() <= keep {
297 return Ok(0);
298 }
299
300 for backup in backups.iter().skip(keep) {
301 if let Err(e) = std::fs::remove_file(&backup.path) {
302 tracing::warn!(
303 path = %backup.path.display(),
304 error = %e,
305 "Failed to prune backup"
306 );
307 } else {
308 deleted += 1;
309 }
310 }
311
312 Ok(deleted)
313}
314
315pub async fn preflight_migration_backup(db_path: &Path) -> Result<Option<PathBuf>, StorageError> {
321 let metadata = match std::fs::metadata(db_path) {
323 Ok(m) if m.len() > 0 => m,
324 _ => return Ok(None),
325 };
326
327 tracing::info!(
328 db = %db_path.display(),
329 size_bytes = metadata.len(),
330 "Creating pre-migration backup"
331 );
332
333 let path_str = db_path.to_string_lossy();
335 let options = SqliteConnectOptions::new()
336 .filename(&*path_str)
337 .read_only(true);
338
339 let pool = SqlitePoolOptions::new()
340 .max_connections(1)
341 .connect_with(options)
342 .await
343 .map_err(|e| StorageError::Connection { source: e })?;
344
345 let backup_dir = db_path
346 .parent()
347 .unwrap_or_else(|| Path::new("."))
348 .join("backups");
349
350 let result = create_backup_with_prefix(&pool, &backup_dir, "pre_migration").await?;
351
352 pool.close().await;
353
354 tracing::info!(
355 path = %result.path.display(),
356 size_bytes = result.size_bytes,
357 duration_ms = result.duration_ms,
358 "Pre-migration backup complete"
359 );
360
361 prune_preflight_backups(&backup_dir, 3);
363
364 Ok(Some(result.path))
365}
366
/// Best-effort removal of all but the `keep` newest `pre_migration_*.db` files
/// in `backup_dir`.
///
/// Files are ranked by descending path order, so the timestamped names sort
/// newest first. An unreadable directory and failed deletions are silently
/// ignored.
fn prune_preflight_backups(backup_dir: &Path, keep: usize) {
    let Ok(listing) = std::fs::read_dir(backup_dir) else {
        return;
    };

    let is_preflight = |candidate: &Path| {
        candidate
            .file_name()
            .and_then(|n| n.to_str())
            .is_some_and(|n| n.starts_with("pre_migration_") && n.ends_with(".db"))
    };

    let mut snapshots: Vec<PathBuf> = listing
        .flatten()
        .map(|entry| entry.path())
        .filter(|p| is_preflight(p.as_path()))
        .collect();

    // Sort ascending and walk from the end, so the newest entries come first.
    snapshots.sort();
    snapshots.iter().rev().skip(keep).for_each(|old| {
        let _ = std::fs::remove_file(old);
    });
}
391
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::init_db;

    /// Initialise a file-backed database at `<dir>/test.db`; returns its pool and path.
    async fn file_test_db(dir: &std::path::Path) -> (DbPool, PathBuf) {
        let path = dir.join("test.db");
        let db = init_db(&path.to_string_lossy())
            .await
            .expect("init file db");
        (db, path)
    }

    #[tokio::test]
    async fn create_and_validate_backup() {
        let tmp = tempfile::tempdir().expect("create temp dir");
        let (db, _db_file) = file_test_db(tmp.path()).await;

        sqlx::query(
            "INSERT INTO action_log (action_type, status, message) \
             VALUES ('test', 'success', 'backup test')",
        )
        .execute(&db)
        .await
        .expect("insert");

        let backups_dir = tmp.path().join("backups");
        let created = create_backup(&db, &backups_dir).await.expect("backup");

        assert!(created.path.exists());
        assert!(created.size_bytes > 0);
        let file_name = created
            .path
            .file_name()
            .and_then(|n| n.to_str())
            .expect("backup file name");
        assert!(file_name.starts_with("tuitbot_"));

        let report = validate_backup(&created.path).await.expect("validate");
        assert!(report.valid);
        assert!(!report.tables.is_empty());
        assert!(report.tables.iter().any(|t| t == "action_log"));

        db.close().await;
    }

    #[tokio::test]
    async fn validate_nonexistent_file() {
        let report = validate_backup(Path::new("/nonexistent/backup.db"))
            .await
            .expect("validate");
        assert!(!report.valid);
    }

    #[tokio::test]
    async fn list_and_prune_backups() {
        let tmp = tempfile::tempdir().expect("create temp dir");

        for n in 1..=5 {
            let fake = tmp.path().join(format!("tuitbot_20240101_00000{n}.db"));
            std::fs::write(fake, "fake").expect("write");
        }

        let listed = list_backups(tmp.path());
        assert_eq!(listed.len(), 5);
        let newest = listed[0].timestamp.as_deref().unwrap();
        let oldest = listed[4].timestamp.as_deref().unwrap();
        assert!(newest > oldest);

        let pruned = prune_backups(tmp.path(), 2).expect("prune");
        assert_eq!(pruned, 3);

        assert_eq!(list_backups(tmp.path()).len(), 2);
    }

    #[tokio::test]
    async fn backup_and_restore() {
        let tmp = tempfile::tempdir().expect("create temp dir");
        let (db, _db_file) = file_test_db(tmp.path()).await;

        sqlx::query(
            "INSERT INTO action_log (action_type, status, message) \
             VALUES ('test', 'success', 'restore test')",
        )
        .execute(&db)
        .await
        .expect("insert");

        let backups_dir = tmp.path().join("backups");
        let created = create_backup(&db, &backups_dir).await.expect("backup");
        db.close().await;

        let restored = tmp.path().join("restored.db");
        restore_from_backup(&created.path, &restored)
            .await
            .expect("restore");
        assert!(restored.exists());

        let restored_pool = SqlitePoolOptions::new()
            .max_connections(1)
            .connect_with(
                SqliteConnectOptions::new()
                    .filename(restored.to_string_lossy().as_ref())
                    .read_only(true),
            )
            .await
            .expect("open restored");

        let (rows,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM action_log")
            .fetch_one(&restored_pool)
            .await
            .expect("count");
        assert_eq!(rows, 1);

        restored_pool.close().await;
    }
}