// rustio_core/migrations.rs

//! Forward-only migrations tracked in a SQLite table.
//!
//! Migrations are plain `.sql` files in a directory, named `NNNN_<slug>.sql`
//! (auto-numbered by [`generate`]). [`apply`] runs pending migrations in
//! filename order, each inside its own transaction, and records applied
//! filenames in the `rustio_migrations` table so reruns are idempotent.

use std::collections::HashSet;
use std::fs;
use std::path::{Path, PathBuf};

use sqlx::Row as _;

use crate::error::Error;
use crate::orm::Db;

/// Name of the SQLite table that records which migration files have been applied.
const TRACKING_TABLE: &str = "rustio_migrations";

/// Knobs for [`apply_with`]; `ApplyOptions::default()` runs quietly.
#[derive(Clone, Copy, Debug, Default)]
pub struct ApplyOptions {
    /// When `true`, echo each migration's filename and every statement to
    /// stderr just before it is executed.
    pub verbose: bool,
}

26pub fn list(dir: &Path) -> Result<Vec<PathBuf>, Error> {
27    if !dir.exists() {
28        return Ok(Vec::new());
29    }
30    let entries = fs::read_dir(dir).map_err(|e| Error::Internal(e.to_string()))?;
31    let mut files: Vec<PathBuf> = entries
32        .filter_map(|e| e.ok())
33        .filter(|e| {
34            e.file_type().map(|t| t.is_file()).unwrap_or(false)
35                && e.path().extension().and_then(|s| s.to_str()) == Some("sql")
36        })
37        .map(|e| e.path())
38        .collect();
39    files.sort();
40    Ok(files)
41}
42
43pub fn generate(dir: &Path, name: &str, content: &str) -> Result<PathBuf, Error> {
44    let sanitized = sanitize_name(name);
45    if sanitized.is_empty() {
46        return Err(Error::BadRequest(
47            "migration name cannot be empty".to_string(),
48        ));
49    }
50    fs::create_dir_all(dir).map_err(|e| Error::Internal(e.to_string()))?;
51    let existing = list(dir)?;
52    let next = next_number(&existing);
53    let filename = format!("{:04}_{}.sql", next, sanitized);
54    let path = dir.join(filename);
55    fs::write(&path, content).map_err(|e| Error::Internal(e.to_string()))?;
56    Ok(path)
57}
58
/// One row of the tracking table: a migration that has already been applied.
#[derive(Clone, Debug)]
pub struct MigrationRecord {
    /// File name (not the full path) of the applied migration, e.g. `0001_init.sql`.
    pub filename: String,
    /// Value of the `applied_at` column. SQLite fills it with `datetime('now')`
    /// when the row is inserted.
    pub applied_at: String,
}

65#[derive(Debug)]
66pub struct Status {
67    pub applied: Vec<MigrationRecord>,
68    pub pending: Vec<String>,
69}
70
71pub async fn applied(db: &Db) -> Result<Vec<MigrationRecord>, Error> {
72    ensure_tracking_table(db).await?;
73    let rows = sqlx::query(&format!(
74        "SELECT filename, applied_at FROM {TRACKING_TABLE} ORDER BY filename"
75    ))
76    .fetch_all(db.pool())
77    .await?;
78    Ok(rows
79        .iter()
80        .map(|r| MigrationRecord {
81            filename: r.get(0),
82            applied_at: r.get(1),
83        })
84        .collect())
85}
86
87pub async fn status(db: &Db, dir: &Path) -> Result<Status, Error> {
88    let applied_records = applied(db).await?;
89    let applied_names: HashSet<String> =
90        applied_records.iter().map(|r| r.filename.clone()).collect();
91    let files = list(dir)?;
92    let pending: Vec<String> = files
93        .iter()
94        .filter_map(|p| p.file_name().and_then(|n| n.to_str()).map(String::from))
95        .filter(|n| !applied_names.contains(n))
96        .collect();
97    Ok(Status {
98        applied: applied_records,
99        pending,
100    })
101}
102
103pub async fn apply(db: &Db, dir: &Path) -> Result<Vec<String>, Error> {
104    apply_with(db, dir, ApplyOptions::default()).await
105}
106
107pub async fn apply_with(db: &Db, dir: &Path, opts: ApplyOptions) -> Result<Vec<String>, Error> {
108    ensure_tracking_table(db).await?;
109
110    let rows = sqlx::query(&format!("SELECT filename FROM {TRACKING_TABLE}"))
111        .fetch_all(db.pool())
112        .await?;
113    let already_applied: HashSet<String> = rows.iter().map(|r| r.get::<String, _>(0)).collect();
114
115    let files = list(dir)?;
116    let mut newly_applied = Vec::new();
117
118    for path in files {
119        let filename = match path.file_name().and_then(|n| n.to_str()) {
120            Some(n) => n.to_string(),
121            None => continue,
122        };
123        if already_applied.contains(&filename) {
124            continue;
125        }
126
127        let sql = fs::read_to_string(&path)
128            .map_err(|e| Error::Internal(format!("reading {filename}: {e}")))?;
129
130        if opts.verbose {
131            eprintln!("-- applying {filename}");
132        }
133
134        let mut tx = db.pool().begin().await?;
135        for stmt in split_sql(&sql) {
136            if opts.verbose {
137                eprintln!("  {}", stmt);
138            }
139            sqlx::query(&stmt)
140                .execute(&mut *tx)
141                .await
142                .map_err(|e| Error::Internal(format!("migration {filename} failed: {e}")))?;
143        }
144        sqlx::query(&format!(
145            "INSERT INTO {TRACKING_TABLE} (filename) VALUES (?)"
146        ))
147        .bind(&filename)
148        .execute(&mut *tx)
149        .await?;
150        tx.commit().await?;
151
152        newly_applied.push(filename);
153    }
154
155    Ok(newly_applied)
156}
157
158async fn ensure_tracking_table(db: &Db) -> Result<(), Error> {
159    db.execute(&format!(
160        "CREATE TABLE IF NOT EXISTS {TRACKING_TABLE} (
161            filename TEXT PRIMARY KEY,
162            applied_at TEXT NOT NULL DEFAULT (datetime('now'))
163        )"
164    ))
165    .await
166}
167
/// Picks the number for the next generated migration. It is one more than
/// the largest numeric prefix (the text before the first `_` in the file
/// name) among `files`, or `1` when no file has such a prefix.
///
/// Names without an `_`, or whose prefix does not parse as a `u32`, are
/// ignored.
fn next_number(files: &[PathBuf]) -> u32 {
    let mut highest: Option<u32> = None;
    for path in files {
        let number = path
            .file_name()
            .and_then(|name| name.to_str())
            .and_then(|name| name.split_once('_'))
            .and_then(|(prefix, _)| prefix.parse::<u32>().ok());
        if let Some(n) = number {
            highest = Some(highest.map_or(n, |h| h.max(n)));
        }
    }
    match highest {
        Some(n) => n + 1,
        None => 1,
    }
}

/// Turns a free-form migration name into a lowercase ASCII slug.
///
/// ASCII letters and digits are kept (lowercased). Every run of other
/// characters between them becomes a single `_`, and no `_` appears at
/// either end. Returns an empty string when `name` has no ASCII
/// alphanumerics.
fn sanitize_name(name: &str) -> String {
    let mut slug = String::with_capacity(name.len());
    let mut saw_separator = false;
    for ch in name.chars() {
        if !ch.is_ascii_alphanumeric() {
            saw_separator = true;
            continue;
        }
        // Emit one `_` for the preceding separator run, but never at the start.
        if saw_separator && !slug.is_empty() {
            slug.push('_');
        }
        saw_separator = false;
        slug.push(ch.to_ascii_lowercase());
    }
    slug
}

/// Split a migration file into individual statements.
///
/// A `;` ends a statement only when it appears in plain SQL text. It is kept
/// as part of the statement when it appears inside:
///
/// - a single-quoted string literal (`'a;b'`); a doubled quote (`''`) inside
///   the literal is an escape, not a terminator;
/// - a quoted identifier (`"a;b"`, `` `a;b` ``, `[a;b]`); doubled `"` and
///   `` ` `` characters are escapes in the same way;
/// - a `--` line comment or a `/* ... */` block comment;
/// - the `BEGIN ... END` body of a `CREATE [TEMP | TEMPORARY] TRIGGER`
///   statement, whose inner statements are themselves `;`-terminated.
///
/// Returned statements are trimmed, never empty, and exclude the terminating
/// `;`. Comments stay attached to the statement text they appear in.
fn split_sql(sql: &str) -> Vec<String> {
    let mut out = Vec::new();
    let mut current = String::new();
    // The bare word (keyword or identifier) being read right now. Once it is
    // complete, it is fed to `trigger` to recognise trigger keywords.
    let mut word = String::new();
    let mut trigger = TriggerScan::new();
    let mut chars = sql.chars().peekable();

    while let Some(c) = chars.next() {
        if c.is_alphanumeric() || c == '_' || c == '$' {
            word.push(c);
            current.push(c);
            continue;
        }
        if !word.is_empty() {
            trigger.observe(&word);
            word.clear();
        }

        match c {
            '\'' | '"' | '`' => {
                current.push(c);
                push_quoted(&mut chars, &mut current, c);
            }
            '[' => {
                // Bracket-quoted identifiers have no escape sequence.
                current.push(c);
                for c1 in chars.by_ref() {
                    current.push(c1);
                    if c1 == ']' {
                        break;
                    }
                }
            }
            '-' if chars.peek() == Some(&'-') => {
                // Line comment: copy through the end of the line.
                current.push(c);
                for c1 in chars.by_ref() {
                    current.push(c1);
                    if c1 == '\n' {
                        break;
                    }
                }
            }
            '/' if chars.peek() == Some(&'*') => {
                // Block comment: copy through the closing `*/`.
                chars.next();
                current.push_str("/*");
                while let Some(c1) = chars.next() {
                    current.push(c1);
                    if c1 == '*' && chars.peek() == Some(&'/') {
                        chars.next();
                        current.push('/');
                        break;
                    }
                }
            }
            // Inside a trigger body, `;` ends a body statement, not the
            // enclosing `CREATE TRIGGER`.
            ';' if trigger.in_body() => current.push(c),
            ';' => {
                push_statement(&mut out, &current);
                current.clear();
                trigger = TriggerScan::new();
            }
            _ => current.push(c),
        }
    }

    push_statement(&mut out, &current);
    out
}

/// Appends `text`, trimmed, to `out` unless it is blank.
fn push_statement(out: &mut Vec<String>, text: &str) {
    let trimmed = text.trim();
    if !trimmed.is_empty() {
        out.push(trimmed.to_string());
    }
}

/// Copies the rest of a quoted token (a string literal or quoted identifier)
/// from `chars` into `out`. The opening `quote` must already be in `out`. A
/// doubled quote character is an escape, not the end of the token. An
/// unterminated token consumes the rest of the input.
fn push_quoted(
    chars: &mut std::iter::Peekable<std::str::Chars<'_>>,
    out: &mut String,
    quote: char,
) {
    while let Some(c) = chars.next() {
        out.push(c);
        if c == quote {
            if chars.peek() == Some(&quote) {
                chars.next();
                out.push(quote);
            } else {
                break;
            }
        }
    }
}

/// How far the leading words of the current statement have got toward
/// matching `CREATE [TEMP | TEMPORARY] TRIGGER`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TriggerHead {
    Start,
    Create,
    CreateTemp,
    Trigger,
    Other,
}

/// Keyword tracker used by [`split_sql`] to recognise trigger bodies.
struct TriggerScan {
    head: TriggerHead,
    /// `BEGIN`/`CASE` blocks opened but not yet closed by `END` (only
    /// counted once the statement is known to be a trigger).
    depth: u32,
}

impl TriggerScan {
    fn new() -> Self {
        TriggerScan {
            head: TriggerHead::Start,
            depth: 0,
        }
    }

    /// Feeds the next complete bare word of the current statement.
    fn observe(&mut self, word: &str) {
        let is = |kw: &str| word.eq_ignore_ascii_case(kw);
        self.head = match self.head {
            TriggerHead::Trigger => {
                // `CASE ... END` may appear in the `WHEN` clause or in body
                // statements, so it is counted alongside `BEGIN ... END`.
                if is("BEGIN") || is("CASE") {
                    self.depth += 1;
                } else if is("END") {
                    self.depth = self.depth.saturating_sub(1);
                }
                TriggerHead::Trigger
            }
            TriggerHead::Start if is("CREATE") => TriggerHead::Create,
            TriggerHead::Create if is("TEMP") || is("TEMPORARY") => TriggerHead::CreateTemp,
            TriggerHead::Create | TriggerHead::CreateTemp if is("TRIGGER") => {
                TriggerHead::Trigger
            }
            _ => TriggerHead::Other,
        };
    }

    /// True while scanning inside a trigger's `BEGIN ... END` body.
    fn in_body(&self) -> bool {
        self.head == TriggerHead::Trigger && self.depth > 0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a fresh, not-yet-created path under the system temp dir. The
    /// path is made unique by the process id and the current time.
    fn tmp(prefix: &str) -> PathBuf {
        let path = std::env::temp_dir().join(format!(
            "rustio-mig-{prefix}-{}-{}",
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos()
        ));
        let _ = fs::remove_dir_all(&path);
        path
    }

    #[test]
    fn sanitize_lowercases_and_underscores() {
        assert_eq!(sanitize_name("Add Blog Table"), "add_blog_table");
        assert_eq!(sanitize_name("create-users-table"), "create_users_table");
        assert_eq!(sanitize_name("add  spaces"), "add_spaces");
        assert_eq!(sanitize_name("CamelCase"), "camelcase");
    }

    #[test]
    fn sanitize_trims_outer_separators() {
        assert_eq!(sanitize_name("_add_"), "add");
        assert_eq!(sanitize_name("--blog--"), "blog");
    }

    #[test]
    fn sanitize_empty_returns_empty() {
        assert_eq!(sanitize_name(""), "");
        assert_eq!(sanitize_name("   "), "");
        assert_eq!(sanitize_name("!!!"), "");
    }

    #[test]
    fn next_number_starts_at_one() {
        assert_eq!(next_number(&[]), 1);
    }

    #[test]
    fn next_number_follows_highest() {
        let files = vec![
            PathBuf::from("migrations/0001_first.sql"),
            PathBuf::from("migrations/0003_third.sql"),
            PathBuf::from("migrations/0002_second.sql"),
        ];
        assert_eq!(next_number(&files), 4);
    }

    #[test]
    fn next_number_ignores_non_numeric_prefixes() {
        let files = vec![
            PathBuf::from("migrations/readme.sql"),
            PathBuf::from("migrations/0005_real.sql"),
        ];
        assert_eq!(next_number(&files), 6);
    }

    #[test]
    fn list_returns_empty_for_missing_dir() {
        let dir = tmp("list-missing");
        assert!(list(&dir).unwrap().is_empty());
    }

    #[test]
    fn list_returns_only_sql_files_sorted() {
        let dir = tmp("list-sorted");
        // A directory whose name ends in `.sql` must not be listed.
        fs::create_dir_all(dir.join("nested.sql")).unwrap();
        fs::write(dir.join("0002_b.sql"), "").unwrap();
        fs::write(dir.join("0001_a.sql"), "").unwrap();
        fs::write(dir.join("notes.txt"), "").unwrap();
        let names: Vec<String> = list(&dir)
            .unwrap()
            .iter()
            .map(|p| p.file_name().unwrap().to_string_lossy().into_owned())
            .collect();
        assert_eq!(names, vec!["0001_a.sql", "0002_b.sql"]);
        fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn split_sql_handles_multiple_statements() {
        let sql = "CREATE TABLE a (id INT); CREATE TABLE b (id INT);";
        let stmts = split_sql(sql);
        assert_eq!(
            stmts,
            vec![
                String::from("CREATE TABLE a (id INT)"),
                String::from("CREATE TABLE b (id INT)"),
            ]
        );
    }

    #[test]
    fn split_sql_ignores_empty_trailing() {
        assert!(split_sql(";;  ;").is_empty());
    }

    #[test]
    fn split_sql_preserves_semicolon_inside_string_literal() {
        assert_eq!(
            split_sql("INSERT INTO t VALUES ('a;b'); CREATE TABLE x (id INT);"),
            vec![
                String::from("INSERT INTO t VALUES ('a;b')"),
                String::from("CREATE TABLE x (id INT)"),
            ]
        );
    }

    #[test]
    fn split_sql_handles_escaped_single_quote() {
        assert_eq!(
            split_sql("INSERT VALUES ('it''s; fine');"),
            vec![String::from("INSERT VALUES ('it''s; fine')")]
        );
    }

    #[test]
    fn split_sql_skips_semicolons_inside_line_comment() {
        assert_eq!(
            split_sql("-- first; second\nCREATE TABLE t (id INT);"),
            vec![String::from("-- first; second\nCREATE TABLE t (id INT)")]
        );
    }

    #[test]
    fn split_sql_skips_semicolons_inside_block_comment() {
        assert_eq!(
            split_sql("/* a;b;c */ CREATE TABLE t (id INT);"),
            vec![String::from("/* a;b;c */ CREATE TABLE t (id INT)")]
        );
    }

    #[test]
    fn generate_creates_files_with_numbered_prefixes() {
        let dir = tmp("gen");
        let p1 = generate(&dir, "create users", "-- one").unwrap();
        let p2 = generate(&dir, "add index", "-- two").unwrap();
        assert!(p1
            .file_name()
            .unwrap()
            .to_string_lossy()
            .starts_with("0001_create_users"));
        assert!(p2
            .file_name()
            .unwrap()
            .to_string_lossy()
            .starts_with("0002_add_index"));
        assert_eq!(fs::read_to_string(&p1).unwrap(), "-- one");
        fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn generate_rejects_empty_name_after_sanitization() {
        let dir = tmp("gen-empty");
        assert!(matches!(
            generate(&dir, "!!!", ""),
            Err(Error::BadRequest(_))
        ));
        fs::remove_dir_all(&dir).ok();
    }

    #[tokio::test]
    async fn apply_creates_tracking_table_even_with_no_migrations() {
        let db = Db::memory().await.unwrap();
        let dir = tmp("apply-empty");
        fs::create_dir_all(&dir).unwrap();
        let applied = apply(&db, &dir).await.unwrap();
        assert!(applied.is_empty());
        let row = sqlx::query("SELECT COUNT(*) FROM rustio_migrations")
            .fetch_one(db.pool())
            .await
            .unwrap();
        let count: i64 = row.get(0);
        assert_eq!(count, 0);
        fs::remove_dir_all(&dir).ok();
    }

    #[tokio::test]
    async fn apply_runs_pending_and_is_idempotent() {
        let db = Db::memory().await.unwrap();
        let dir = tmp("apply-idem");
        fs::create_dir_all(&dir).unwrap();
        fs::write(dir.join("0001_create.sql"), "CREATE TABLE t (id INTEGER);").unwrap();
        fs::write(
            dir.join("0002_insert.sql"),
            "INSERT INTO t (id) VALUES (42);",
        )
        .unwrap();

        let first = apply(&db, &dir).await.unwrap();
        assert_eq!(first, vec!["0001_create.sql", "0002_insert.sql"]);

        let second = apply(&db, &dir).await.unwrap();
        assert!(second.is_empty());

        let row = sqlx::query("SELECT id FROM t")
            .fetch_one(db.pool())
            .await
            .unwrap();
        let id: i64 = row.get(0);
        assert_eq!(id, 42);

        fs::remove_dir_all(&dir).ok();
    }

    #[tokio::test]
    async fn apply_picks_up_new_migration_added_later() {
        let db = Db::memory().await.unwrap();
        let dir = tmp("apply-followup");
        fs::create_dir_all(&dir).unwrap();
        fs::write(
            dir.join("0001_first.sql"),
            "CREATE TABLE first (id INTEGER);",
        )
        .unwrap();
        apply(&db, &dir).await.unwrap();

        fs::write(
            dir.join("0002_second.sql"),
            "CREATE TABLE second (id INTEGER);",
        )
        .unwrap();
        let applied = apply(&db, &dir).await.unwrap();
        assert_eq!(applied, vec!["0002_second.sql"]);

        sqlx::query("INSERT INTO first (id) VALUES (1)")
            .execute(db.pool())
            .await
            .unwrap();
        sqlx::query("INSERT INTO second (id) VALUES (2)")
            .execute(db.pool())
            .await
            .unwrap();

        fs::remove_dir_all(&dir).ok();
    }

    #[tokio::test]
    async fn status_reports_applied_and_pending_separately() {
        let db = Db::memory().await.unwrap();
        let dir = tmp("status");
        fs::create_dir_all(&dir).unwrap();
        fs::write(dir.join("0001_a.sql"), "CREATE TABLE a (id INTEGER);").unwrap();
        fs::write(dir.join("0002_b.sql"), "CREATE TABLE b (id INTEGER);").unwrap();
        fs::write(dir.join("0003_c.sql"), "CREATE TABLE c (id INTEGER);").unwrap();

        // Apply the first three migrations.
        let applied_now = apply(&db, &dir).await.unwrap();
        assert_eq!(applied_now.len(), 3);

        // Add a fourth that has not been applied yet.
        fs::write(dir.join("0004_d.sql"), "CREATE TABLE d (id INTEGER);").unwrap();

        let s = status(&db, &dir).await.unwrap();
        assert_eq!(s.applied.len(), 3);
        assert_eq!(
            s.applied
                .iter()
                .map(|r| r.filename.as_str())
                .collect::<Vec<_>>(),
            vec!["0001_a.sql", "0002_b.sql", "0003_c.sql"]
        );
        assert_eq!(s.pending, vec!["0004_d.sql"]);

        fs::remove_dir_all(&dir).ok();
    }

    #[tokio::test]
    async fn status_on_empty_project_returns_empty_both() {
        let db = Db::memory().await.unwrap();
        let dir = tmp("status-empty");
        fs::create_dir_all(&dir).unwrap();
        let s = status(&db, &dir).await.unwrap();
        assert!(s.applied.is_empty());
        assert!(s.pending.is_empty());
        fs::remove_dir_all(&dir).ok();
    }

    #[tokio::test]
    async fn failed_migration_rolls_back_and_is_not_marked_applied() {
        let db = Db::memory().await.unwrap();
        let dir = tmp("apply-failure");
        fs::create_dir_all(&dir).unwrap();
        fs::write(dir.join("0001_ok.sql"), "CREATE TABLE ok (id INTEGER);").unwrap();
        // Creates the same table again, so this migration fails.
        fs::write(dir.join("0002_bad.sql"), "CREATE TABLE ok (id INTEGER);").unwrap();

        let result = apply(&db, &dir).await;
        assert!(result.is_err());

        let rows = sqlx::query("SELECT filename FROM rustio_migrations")
            .fetch_all(db.pool())
            .await
            .unwrap();
        let applied: Vec<String> = rows.iter().map(|r| r.get::<String, _>(0)).collect();
        assert_eq!(applied, vec!["0001_ok.sql"]);

        fs::remove_dir_all(&dir).ok();
    }
}
550}