Skip to main content

rustio_core/
migrations.rs

1//! Forward-only migrations tracked in a SQLite table.
2//!
3//! Migrations are plain `.sql` files in a directory, named `NNNN_<slug>.sql`
4//! (auto-numbered by [`generate`]). [`apply`] runs pending migrations in
5//! filename order, each inside its own transaction, and records applied
6//! filenames in the `rustio_migrations` table so reruns are idempotent.
7
8use std::collections::HashSet;
9use std::fs;
10use std::path::{Path, PathBuf};
11
12use sqlx::Row as _;
13
14use crate::error::Error;
15use crate::orm::Db;
16
/// Name of the SQLite table that records which migration files have run.
const TRACKING_TABLE: &str = "rustio_migrations";
18
/// Tuning knobs for [`apply_with`].
///
/// `ApplyOptions::default()` is what [`apply`] uses (quiet mode).
#[derive(Clone, Copy, Debug, Default)]
pub struct ApplyOptions {
    /// When `true`, each migration filename and each SQL statement is echoed
    /// to stderr just before it executes.
    pub verbose: bool,
}
25
26pub fn list(dir: &Path) -> Result<Vec<PathBuf>, Error> {
27    if !dir.exists() {
28        return Ok(Vec::new());
29    }
30    let entries = fs::read_dir(dir).map_err(|e| Error::Internal(e.to_string()))?;
31    let mut files: Vec<PathBuf> = entries
32        .filter_map(|e| e.ok())
33        .filter(|e| {
34            e.file_type().map(|t| t.is_file()).unwrap_or(false)
35                && e.path().extension().and_then(|s| s.to_str()) == Some("sql")
36        })
37        .map(|e| e.path())
38        .collect();
39    files.sort();
40    Ok(files)
41}
42
43pub fn generate(dir: &Path, name: &str, content: &str) -> Result<PathBuf, Error> {
44    let sanitized = sanitize_name(name);
45    if sanitized.is_empty() {
46        return Err(Error::BadRequest(
47            "migration name cannot be empty".to_string(),
48        ));
49    }
50    fs::create_dir_all(dir).map_err(|e| Error::Internal(e.to_string()))?;
51    let existing = list(dir)?;
52    let next = next_number(&existing);
53    let filename = format!("{:04}_{}.sql", next, sanitized);
54    let path = dir.join(filename);
55    fs::write(&path, content).map_err(|e| Error::Internal(e.to_string()))?;
56    Ok(path)
57}
58
/// One row of the migration tracking table.
#[derive(Clone, Debug)]
pub struct MigrationRecord {
    /// Filename of the applied migration, e.g. `0001_create_users.sql`.
    pub filename: String,
    /// Value of the `applied_at` column, as stored (text).
    pub applied_at: String,
}

/// Result of [`status`]: what has run and what is still waiting.
#[derive(Debug)]
pub struct Status {
    /// Migrations recorded in the tracking table.
    pub applied: Vec<MigrationRecord>,
    /// Filenames present in the migrations directory with no tracking record.
    pub pending: Vec<String>,
}
70
71pub async fn applied(db: &Db) -> Result<Vec<MigrationRecord>, Error> {
72    ensure_tracking_table(db).await?;
73    let rows = sqlx::query(&format!(
74        "SELECT filename, applied_at FROM {TRACKING_TABLE} ORDER BY filename"
75    ))
76    .fetch_all(db.pool())
77    .await?;
78    Ok(rows
79        .iter()
80        .map(|r| MigrationRecord {
81            filename: r.get(0),
82            applied_at: r.get(1),
83        })
84        .collect())
85}
86
87pub async fn status(db: &Db, dir: &Path) -> Result<Status, Error> {
88    let applied_records = applied(db).await?;
89    let applied_names: HashSet<String> =
90        applied_records.iter().map(|r| r.filename.clone()).collect();
91    let files = list(dir)?;
92    let pending: Vec<String> = files
93        .iter()
94        .filter_map(|p| p.file_name().and_then(|n| n.to_str()).map(String::from))
95        .filter(|n| !applied_names.contains(n))
96        .collect();
97    Ok(Status {
98        applied: applied_records,
99        pending,
100    })
101}
102
103pub async fn apply(db: &Db, dir: &Path) -> Result<Vec<String>, Error> {
104    apply_with(db, dir, ApplyOptions::default()).await
105}
106
107pub async fn apply_with(db: &Db, dir: &Path, opts: ApplyOptions) -> Result<Vec<String>, Error> {
108    // Core auth tables (`rustio_users`, `rustio_sessions`) exist on
109    // every DB. They're created before any user-level migration runs
110    // so login works as soon as a project boots, regardless of what's
111    // in `migrations/`. The statements are `CREATE IF NOT EXISTS` so
112    // calling this repeatedly is safe.
113    crate::auth::ensure_core_tables(db).await?;
114    ensure_tracking_table(db).await?;
115
116    let rows = sqlx::query(&format!("SELECT filename FROM {TRACKING_TABLE}"))
117        .fetch_all(db.pool())
118        .await?;
119    let already_applied: HashSet<String> = rows.iter().map(|r| r.get::<String, _>(0)).collect();
120
121    let files = list(dir)?;
122    let mut newly_applied = Vec::new();
123
124    for path in files {
125        let filename = match path.file_name().and_then(|n| n.to_str()) {
126            Some(n) => n.to_string(),
127            None => continue,
128        };
129        if already_applied.contains(&filename) {
130            continue;
131        }
132
133        let sql = fs::read_to_string(&path)
134            .map_err(|e| Error::Internal(format!("reading {filename}: {e}")))?;
135
136        if opts.verbose {
137            eprintln!("-- applying {filename}");
138        }
139
140        let mut tx = db.pool().begin().await?;
141        for stmt in split_sql(&sql) {
142            if opts.verbose {
143                eprintln!("  {}", stmt);
144            }
145            sqlx::query(&stmt)
146                .execute(&mut *tx)
147                .await
148                .map_err(|e| Error::Internal(format!("migration {filename} failed: {e}")))?;
149        }
150        sqlx::query(&format!(
151            "INSERT INTO {TRACKING_TABLE} (filename) VALUES (?)"
152        ))
153        .bind(&filename)
154        .execute(&mut *tx)
155        .await?;
156        tx.commit().await?;
157
158        newly_applied.push(filename);
159    }
160
161    Ok(newly_applied)
162}
163
164async fn ensure_tracking_table(db: &Db) -> Result<(), Error> {
165    db.execute(&format!(
166        "CREATE TABLE IF NOT EXISTS {TRACKING_TABLE} (
167            filename TEXT PRIMARY KEY,
168            applied_at TEXT NOT NULL DEFAULT (datetime('now'))
169        )"
170    ))
171    .await
172}
173
/// Return the number for the next migration.
///
/// This is one past the highest numeric prefix (the part before the first
/// `_`) among the filenames in `files`, or `1` when none has such a prefix.
fn next_number(files: &[PathBuf]) -> u32 {
    let mut highest: Option<u32> = None;
    for path in files {
        let name = match path.file_name().and_then(|n| n.to_str()) {
            Some(name) => name,
            None => continue,
        };
        let parsed = name
            .split_once('_')
            .and_then(|(prefix, _)| prefix.parse::<u32>().ok());
        if let Some(n) = parsed {
            highest = Some(highest.map_or(n, |h| h.max(n)));
        }
    }
    match highest {
        Some(n) => n + 1,
        None => 1,
    }
}
186
/// Turn a free-form migration name into a filename slug.
///
/// ASCII letters and digits are kept and lowercased. Every run of other
/// characters between them becomes a single `_`, and nothing is emitted at
/// either end. The result is empty when `name` has no ASCII alphanumerics.
fn sanitize_name(name: &str) -> String {
    let mut slug = String::with_capacity(name.len());
    let mut gap = false;
    for ch in name.chars() {
        if !ch.is_ascii_alphanumeric() {
            gap = true;
            continue;
        }
        // Only join two alphanumeric runs; leading separators are dropped.
        if gap && !slug.is_empty() {
            slug.push('_');
        }
        gap = false;
        slug.push(ch.to_ascii_lowercase());
    }
    slug
}
203
/// Split a migration file into individual statements.
///
/// A `;` ends a statement only when it appears outside of all of:
/// - single-quoted string literals (`'...'`, where `''` is an escaped quote);
/// - quoted identifiers: `"..."` and `` `...` `` (a doubled quote character
///   is an escape), and `[...]`;
/// - `--` line comments and `/* ... */` block comments;
/// - the `BEGIN ... END` body of a `CREATE [TEMP|TEMPORARY] TRIGGER`
///   statement, with nested `CASE ... END` expressions accounted for.
///
/// Comments stay in the text of the statement that contains them. A comment
/// after the final `;` is returned as its own comment-only entry. Every
/// entry is trimmed, and empty statements are dropped.
///
/// NOTE: inside a trigger body, an *unquoted* identifier spelled `end`
/// (which SQLite permits) is indistinguishable from the keyword here and can
/// end the body early. Quote such identifiers.
fn split_sql(sql: &str) -> Vec<String> {
    let mut out = Vec::new();
    let mut current = String::new();
    let mut word = String::new();
    let mut scan = StatementScan::default();
    let mut chars = sql.chars().peekable();

    while let Some(c) = chars.next() {
        // Identifier/keyword characters build up `word`. Any other character
        // ends it, and the finished word is shown to the keyword tracker.
        if c.is_alphanumeric() || c == '_' || c == '$' {
            word.push(c);
            current.push(c);
            continue;
        }
        if !word.is_empty() {
            scan.observe_word(&word);
            word.clear();
        }

        match c {
            '\'' | '"' | '`' => consume_quoted(&mut chars, &mut current, c, c),
            '[' => consume_quoted(&mut chars, &mut current, '[', ']'),
            '-' if chars.peek() == Some(&'-') => {
                current.push(c);
                // Copy through the end of the line (newline included) or input.
                for nc in chars.by_ref() {
                    current.push(nc);
                    if nc == '\n' {
                        break;
                    }
                }
            }
            '/' if chars.peek() == Some(&'*') => {
                current.push(c);
                chars.next();
                current.push('*');
                while let Some(bc) = chars.next() {
                    current.push(bc);
                    if bc == '*' {
                        if let Some(slash) = chars.next_if_eq(&'/') {
                            current.push(slash);
                            break;
                        }
                    }
                }
            }
            // Inside a trigger body the `;` belongs to the body, not the statement.
            ';' if scan.inside_body() => current.push(c),
            ';' => {
                let trimmed = current.trim();
                if !trimmed.is_empty() {
                    out.push(trimmed.to_string());
                }
                current.clear();
                scan = StatementScan::default();
            }
            _ => current.push(c),
        }
    }

    let trimmed = current.trim();
    if !trimmed.is_empty() {
        out.push(trimmed.to_string());
    }

    out
}

/// Keyword context of the statement `split_sql` is currently accumulating.
///
/// Used to tell a statement-ending `;` from one inside a trigger body.
#[derive(Default)]
struct StatementScan {
    /// First (up to three) words of the statement, uppercased.
    head: Vec<String>,
    /// Whether the statement is `CREATE [TEMP|TEMPORARY] TRIGGER ...`.
    is_trigger: bool,
    /// Trigger-body nesting. `BEGIN` opens the body, `CASE` nests inside it,
    /// and `END` closes the innermost level.
    depth: usize,
}

impl StatementScan {
    /// Record one word seen outside quotes and comments.
    fn observe_word(&mut self, word: &str) {
        let upper = word.to_ascii_uppercase();
        if self.head.len() < 3 {
            self.head.push(upper.clone());
            let head: Vec<&str> = self.head.iter().map(String::as_str).collect();
            self.is_trigger = matches!(
                head.as_slice(),
                ["CREATE", "TRIGGER", ..] | ["CREATE", "TEMP" | "TEMPORARY", "TRIGGER"]
            );
        }
        if !self.is_trigger {
            return;
        }
        match upper.as_str() {
            "BEGIN" if self.depth == 0 => self.depth = 1,
            "CASE" if self.depth > 0 => self.depth += 1,
            "END" if self.depth > 0 => self.depth -= 1,
            _ => {}
        }
    }

    /// True while inside a trigger's `BEGIN ... END` body.
    fn inside_body(&self) -> bool {
        self.depth > 0
    }
}

/// Copy a quoted token into `current`, from its opening character `open`
/// (already taken from `chars`) through the matching `close`.
///
/// When `open == close`, a doubled closing character is an escaped literal
/// quote and does not end the token. An unterminated token consumes the rest
/// of the input.
fn consume_quoted(
    chars: &mut std::iter::Peekable<std::str::Chars<'_>>,
    current: &mut String,
    open: char,
    close: char,
) {
    current.push(open);
    while let Some(c) = chars.next() {
        current.push(c);
        if c != close {
            continue;
        }
        if open == close {
            if let Some(escaped) = chars.next_if_eq(&close) {
                current.push(escaped);
                continue;
            }
        }
        return;
    }
}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// Return a unique temp-dir path for one test, removing any leftover
    /// from an earlier run. The directory itself is not created.
    fn tmp(prefix: &str) -> PathBuf {
        let path = std::env::temp_dir().join(format!(
            "rustio-mig-{prefix}-{}-{}",
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos()
        ));
        let _ = fs::remove_dir_all(&path);
        path
    }

    #[test]
    fn sanitize_lowercases_and_underscores() {
        assert_eq!(sanitize_name("Add Blog Table"), "add_blog_table");
        assert_eq!(sanitize_name("create-users-table"), "create_users_table");
        assert_eq!(sanitize_name("add  spaces"), "add_spaces");
        assert_eq!(sanitize_name("CamelCase"), "camelcase");
    }

    #[test]
    fn sanitize_trims_outer_separators() {
        assert_eq!(sanitize_name("_add_"), "add");
        assert_eq!(sanitize_name("--blog--"), "blog");
    }

    #[test]
    fn sanitize_empty_returns_empty() {
        assert_eq!(sanitize_name(""), "");
        assert_eq!(sanitize_name("   "), "");
        assert_eq!(sanitize_name("!!!"), "");
    }

    #[test]
    fn next_number_starts_at_one() {
        assert_eq!(next_number(&[]), 1);
    }

    #[test]
    fn next_number_follows_highest() {
        let files = vec![
            PathBuf::from("migrations/0001_first.sql"),
            PathBuf::from("migrations/0003_third.sql"),
            PathBuf::from("migrations/0002_second.sql"),
        ];
        assert_eq!(next_number(&files), 4);
    }

    #[test]
    fn next_number_ignores_non_numeric_prefixes() {
        let files = vec![
            PathBuf::from("migrations/readme.sql"),
            PathBuf::from("migrations/0005_real.sql"),
        ];
        assert_eq!(next_number(&files), 6);
    }

    #[test]
    fn split_sql_handles_multiple_statements() {
        let sql = "CREATE TABLE a (id INT); CREATE TABLE b (id INT);";
        let stmts = split_sql(sql);
        assert_eq!(
            stmts,
            vec![
                String::from("CREATE TABLE a (id INT)"),
                String::from("CREATE TABLE b (id INT)"),
            ]
        );
    }

    #[test]
    fn split_sql_ignores_empty_trailing() {
        assert!(split_sql(";;  ;").is_empty());
    }

    #[test]
    fn split_sql_preserves_semicolon_inside_string_literal() {
        assert_eq!(
            split_sql("INSERT INTO t VALUES ('a;b'); CREATE TABLE x (id INT);"),
            vec![
                String::from("INSERT INTO t VALUES ('a;b')"),
                String::from("CREATE TABLE x (id INT)"),
            ]
        );
    }

    #[test]
    fn split_sql_handles_escaped_single_quote() {
        assert_eq!(
            split_sql("INSERT VALUES ('it''s; fine');"),
            vec![String::from("INSERT VALUES ('it''s; fine')")]
        );
    }

    #[test]
    fn split_sql_skips_semicolons_inside_line_comment() {
        assert_eq!(
            split_sql("-- first; second\nCREATE TABLE t (id INT);"),
            vec![String::from("-- first; second\nCREATE TABLE t (id INT)")]
        );
    }

    #[test]
    fn split_sql_skips_semicolons_inside_block_comment() {
        assert_eq!(
            split_sql("/* a;b;c */ CREATE TABLE t (id INT);"),
            vec![String::from("/* a;b;c */ CREATE TABLE t (id INT)")]
        );
    }

    #[test]
    fn generate_creates_files_with_numbered_prefixes() {
        let dir = tmp("gen");
        let p1 = generate(&dir, "create users", "-- one").unwrap();
        let p2 = generate(&dir, "add index", "-- two").unwrap();
        assert!(p1
            .file_name()
            .unwrap()
            .to_string_lossy()
            .starts_with("0001_create_users"));
        assert!(p2
            .file_name()
            .unwrap()
            .to_string_lossy()
            .starts_with("0002_add_index"));
        assert_eq!(fs::read_to_string(&p1).unwrap(), "-- one");
        fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn generate_rejects_empty_name_after_sanitization() {
        let dir = tmp("gen-empty");
        assert!(matches!(
            generate(&dir, "!!!", ""),
            Err(Error::BadRequest(_))
        ));
        fs::remove_dir_all(&dir).ok();
    }

    #[tokio::test]
    async fn apply_creates_tracking_table_even_with_no_migrations() {
        let db = Db::memory().await.unwrap();
        let dir = tmp("apply-empty");
        fs::create_dir_all(&dir).unwrap();
        let applied = apply(&db, &dir).await.unwrap();
        assert!(applied.is_empty());
        let row = sqlx::query("SELECT COUNT(*) FROM rustio_migrations")
            .fetch_one(db.pool())
            .await
            .unwrap();
        let count: i64 = row.get(0);
        assert_eq!(count, 0);
        fs::remove_dir_all(&dir).ok();
    }

    #[tokio::test]
    async fn apply_runs_pending_and_is_idempotent() {
        let db = Db::memory().await.unwrap();
        let dir = tmp("apply-idem");
        fs::create_dir_all(&dir).unwrap();
        fs::write(dir.join("0001_create.sql"), "CREATE TABLE t (id INTEGER);").unwrap();
        fs::write(
            dir.join("0002_insert.sql"),
            "INSERT INTO t (id) VALUES (42);",
        )
        .unwrap();

        let first = apply(&db, &dir).await.unwrap();
        assert_eq!(first, vec!["0001_create.sql", "0002_insert.sql"]);

        let second = apply(&db, &dir).await.unwrap();
        assert!(second.is_empty());

        let row = sqlx::query("SELECT id FROM t")
            .fetch_one(db.pool())
            .await
            .unwrap();
        let id: i64 = row.get(0);
        assert_eq!(id, 42);

        fs::remove_dir_all(&dir).ok();
    }

    #[tokio::test]
    async fn apply_picks_up_new_migration_added_later() {
        let db = Db::memory().await.unwrap();
        let dir = tmp("apply-followup");
        fs::create_dir_all(&dir).unwrap();
        fs::write(
            dir.join("0001_first.sql"),
            "CREATE TABLE first (id INTEGER);",
        )
        .unwrap();
        apply(&db, &dir).await.unwrap();

        fs::write(
            dir.join("0002_second.sql"),
            "CREATE TABLE second (id INTEGER);",
        )
        .unwrap();
        let applied = apply(&db, &dir).await.unwrap();
        assert_eq!(applied, vec!["0002_second.sql"]);

        sqlx::query("INSERT INTO first (id) VALUES (1)")
            .execute(db.pool())
            .await
            .unwrap();
        sqlx::query("INSERT INTO second (id) VALUES (2)")
            .execute(db.pool())
            .await
            .unwrap();

        fs::remove_dir_all(&dir).ok();
    }

    #[tokio::test]
    async fn status_reports_applied_and_pending_separately() {
        let db = Db::memory().await.unwrap();
        let dir = tmp("status");
        fs::create_dir_all(&dir).unwrap();
        fs::write(dir.join("0001_a.sql"), "CREATE TABLE a (id INTEGER);").unwrap();
        fs::write(dir.join("0002_b.sql"), "CREATE TABLE b (id INTEGER);").unwrap();
        fs::write(dir.join("0003_c.sql"), "CREATE TABLE c (id INTEGER);").unwrap();

        // Apply all three migrations present so far.
        let applied_now = apply(&db, &dir).await.unwrap();
        assert_eq!(applied_now.len(), 3);

        // Add a fourth, not yet applied.
        fs::write(dir.join("0004_d.sql"), "CREATE TABLE d (id INTEGER);").unwrap();

        let s = status(&db, &dir).await.unwrap();
        assert_eq!(s.applied.len(), 3);
        assert_eq!(
            s.applied
                .iter()
                .map(|r| r.filename.as_str())
                .collect::<Vec<_>>(),
            vec!["0001_a.sql", "0002_b.sql", "0003_c.sql"]
        );
        assert_eq!(s.pending, vec!["0004_d.sql"]);

        fs::remove_dir_all(&dir).ok();
    }

    #[tokio::test]
    async fn status_on_empty_project_returns_empty_both() {
        let db = Db::memory().await.unwrap();
        let dir = tmp("status-empty");
        fs::create_dir_all(&dir).unwrap();
        let s = status(&db, &dir).await.unwrap();
        assert!(s.applied.is_empty());
        assert!(s.pending.is_empty());
        fs::remove_dir_all(&dir).ok();
    }

    #[tokio::test]
    async fn failed_migration_rolls_back_and_is_not_marked_applied() {
        let db = Db::memory().await.unwrap();
        let dir = tmp("apply-failure");
        fs::create_dir_all(&dir).unwrap();
        fs::write(dir.join("0001_ok.sql"), "CREATE TABLE ok (id INTEGER);").unwrap();
        // Same table name again, so this migration fails.
        fs::write(dir.join("0002_bad.sql"), "CREATE TABLE ok (id INTEGER);").unwrap();

        let result = apply(&db, &dir).await;
        assert!(result.is_err());

        let rows = sqlx::query("SELECT filename FROM rustio_migrations")
            .fetch_all(db.pool())
            .await
            .unwrap();
        let applied: Vec<String> = rows.iter().map(|r| r.get::<String, _>(0)).collect();
        assert_eq!(applied, vec!["0001_ok.sql"]);

        fs::remove_dir_all(&dir).ok();
    }
}