1use std::collections::HashSet;
9use std::fs;
10use std::path::{Path, PathBuf};
11
12use sqlx::Row as _;
13
14use crate::error::Error;
15use crate::orm::Db;
16
/// Name of the table that records which migration files have been applied
/// (one row per file name; created on demand by `ensure_tracking_table`).
const TRACKING_TABLE: &str = "rustio_migrations";
18
/// Switches controlling how [`apply_with`] runs pending migrations.
///
/// [`apply`] uses `ApplyOptions::default()`, i.e. with `verbose` turned off.
#[derive(Clone, Copy, Debug, Default)]
pub struct ApplyOptions {
    /// When `true`, each migration file name and each statement executed is
    /// echoed to stderr.
    pub verbose: bool,
}
25
26pub fn list(dir: &Path) -> Result<Vec<PathBuf>, Error> {
27 if !dir.exists() {
28 return Ok(Vec::new());
29 }
30 let entries = fs::read_dir(dir).map_err(|e| Error::Internal(e.to_string()))?;
31 let mut files: Vec<PathBuf> = entries
32 .filter_map(|e| e.ok())
33 .filter(|e| {
34 e.file_type().map(|t| t.is_file()).unwrap_or(false)
35 && e.path().extension().and_then(|s| s.to_str()) == Some("sql")
36 })
37 .map(|e| e.path())
38 .collect();
39 files.sort();
40 Ok(files)
41}
42
43pub fn generate(dir: &Path, name: &str, content: &str) -> Result<PathBuf, Error> {
44 let sanitized = sanitize_name(name);
45 if sanitized.is_empty() {
46 return Err(Error::BadRequest(
47 "migration name cannot be empty".to_string(),
48 ));
49 }
50 fs::create_dir_all(dir).map_err(|e| Error::Internal(e.to_string()))?;
51 let existing = list(dir)?;
52 let next = next_number(&existing);
53 let filename = format!("{:04}_{}.sql", next, sanitized);
54 let path = dir.join(filename);
55 fs::write(&path, content).map_err(|e| Error::Internal(e.to_string()))?;
56 Ok(path)
57}
58
/// A migration that has already been applied, as read back from the tracking
/// table.
#[derive(Clone, Debug)]
pub struct MigrationRecord {
    /// File name (not the full path) of the migration, e.g. `0001_init.sql`.
    pub filename: String,
    /// The stored `applied_at` column value, kept as text.
    pub applied_at: String,
}
64
65#[derive(Debug)]
66pub struct Status {
67 pub applied: Vec<MigrationRecord>,
68 pub pending: Vec<String>,
69}
70
71pub async fn applied(db: &Db) -> Result<Vec<MigrationRecord>, Error> {
72 ensure_tracking_table(db).await?;
73 let rows = sqlx::query(&format!(
74 "SELECT filename, applied_at FROM {TRACKING_TABLE} ORDER BY filename"
75 ))
76 .fetch_all(db.pool())
77 .await?;
78 Ok(rows
79 .iter()
80 .map(|r| MigrationRecord {
81 filename: r.get(0),
82 applied_at: r.get(1),
83 })
84 .collect())
85}
86
87pub async fn status(db: &Db, dir: &Path) -> Result<Status, Error> {
88 let applied_records = applied(db).await?;
89 let applied_names: HashSet<String> =
90 applied_records.iter().map(|r| r.filename.clone()).collect();
91 let files = list(dir)?;
92 let pending: Vec<String> = files
93 .iter()
94 .filter_map(|p| p.file_name().and_then(|n| n.to_str()).map(String::from))
95 .filter(|n| !applied_names.contains(n))
96 .collect();
97 Ok(Status {
98 applied: applied_records,
99 pending,
100 })
101}
102
103pub async fn apply(db: &Db, dir: &Path) -> Result<Vec<String>, Error> {
104 apply_with(db, dir, ApplyOptions::default()).await
105}
106
107pub async fn apply_with(db: &Db, dir: &Path, opts: ApplyOptions) -> Result<Vec<String>, Error> {
108 ensure_tracking_table(db).await?;
109
110 let rows = sqlx::query(&format!("SELECT filename FROM {TRACKING_TABLE}"))
111 .fetch_all(db.pool())
112 .await?;
113 let already_applied: HashSet<String> = rows.iter().map(|r| r.get::<String, _>(0)).collect();
114
115 let files = list(dir)?;
116 let mut newly_applied = Vec::new();
117
118 for path in files {
119 let filename = match path.file_name().and_then(|n| n.to_str()) {
120 Some(n) => n.to_string(),
121 None => continue,
122 };
123 if already_applied.contains(&filename) {
124 continue;
125 }
126
127 let sql = fs::read_to_string(&path)
128 .map_err(|e| Error::Internal(format!("reading {filename}: {e}")))?;
129
130 if opts.verbose {
131 eprintln!("-- applying {filename}");
132 }
133
134 let mut tx = db.pool().begin().await?;
135 for stmt in split_sql(&sql) {
136 if opts.verbose {
137 eprintln!(" {}", stmt);
138 }
139 sqlx::query(&stmt)
140 .execute(&mut *tx)
141 .await
142 .map_err(|e| Error::Internal(format!("migration {filename} failed: {e}")))?;
143 }
144 sqlx::query(&format!(
145 "INSERT INTO {TRACKING_TABLE} (filename) VALUES (?)"
146 ))
147 .bind(&filename)
148 .execute(&mut *tx)
149 .await?;
150 tx.commit().await?;
151
152 newly_applied.push(filename);
153 }
154
155 Ok(newly_applied)
156}
157
158async fn ensure_tracking_table(db: &Db) -> Result<(), Error> {
159 db.execute(&format!(
160 "CREATE TABLE IF NOT EXISTS {TRACKING_TABLE} (
161 filename TEXT PRIMARY KEY,
162 applied_at TEXT NOT NULL DEFAULT (datetime('now'))
163 )"
164 ))
165 .await
166}
167
/// Returns the number to use for the next migration file.
///
/// This is one more than the largest numeric prefix found in `files`, where
/// the prefix is the text before the first `_` in the file name and must parse
/// as a `u32`. Names with no `_`, a non-numeric prefix, or a non-UTF-8 file
/// name are ignored. Returns 1 when no prefix is found.
fn next_number(files: &[PathBuf]) -> u32 {
    let mut highest: Option<u32> = None;
    for path in files {
        let name = match path.file_name().and_then(|n| n.to_str()) {
            Some(name) => name,
            None => continue,
        };
        let number = match name.split_once('_').map(|(prefix, _)| prefix.parse::<u32>()) {
            Some(Ok(n)) => n,
            _ => continue,
        };
        highest = Some(highest.map_or(number, |h| h.max(number)));
    }
    match highest {
        Some(h) => h + 1,
        None => 1,
    }
}
180
/// Converts a free-form migration name into a file-name-safe slug.
///
/// ASCII letters and digits are kept and lowercased. Every run of other
/// characters, including non-ASCII letters, becomes a single `_`. The result
/// never starts or ends with `_`, and it is empty if `name` has no ASCII
/// alphanumerics.
fn sanitize_name(name: &str) -> String {
    let mut slug = String::new();
    // Separators are only materialised once another alphanumeric follows, so
    // leading and trailing runs never produce an underscore.
    let mut pending_sep = false;
    for ch in name.chars() {
        if !ch.is_ascii_alphanumeric() {
            pending_sep = !slug.is_empty();
            continue;
        }
        if pending_sep {
            slug.push('_');
            pending_sep = false;
        }
        slug.push(ch.to_ascii_lowercase());
    }
    slug
}
197
/// Splits a SQL script into individual statements on top-level `;`.
///
/// A `;` is *not* treated as a separator when it appears inside:
/// - a single-quoted string literal (`'a;b'`),
/// - a double-quoted or backtick-quoted identifier (`"a;b"`, `` `a;b` ``).
///   For all three quote kinds, doubling the quote character escapes it,
///   e.g. `'it''s'`.
/// - a `--` line comment (up to and including the newline),
/// - a `/* ... */` block comment.
///
/// Each statement is trimmed, and empty statements (e.g. from `;;`) are
/// dropped. Comments stay attached to the text they appear in. An
/// unterminated quote or block comment runs to the end of the input.
///
/// NOTE: every top-level `;` ends a statement, so compound bodies such as
/// `CREATE TRIGGER ... BEGIN ...; ...; END;` are split apart.
fn split_sql(sql: &str) -> Vec<String> {
    /// Pushes the trimmed `current` text (if any) onto `out` and resets it.
    fn flush(out: &mut Vec<String>, current: &mut String) {
        let trimmed = current.trim();
        if !trimmed.is_empty() {
            out.push(trimmed.to_string());
        }
        current.clear();
    }

    let mut out = Vec::new();
    let mut current = String::new();
    let mut chars = sql.chars().peekable();

    while let Some(c) = chars.next() {
        match c {
            quote @ ('\'' | '"' | '`') => {
                // Copy the quoted run verbatim; a doubled quote is an escape.
                current.push(quote);
                while let Some(qc) = chars.next() {
                    current.push(qc);
                    if qc == quote {
                        if chars.peek() == Some(&quote) {
                            current.push(quote);
                            chars.next();
                        } else {
                            break;
                        }
                    }
                }
            }
            '-' if chars.peek() == Some(&'-') => {
                // Line comment: copy through the terminating newline.
                current.push(c);
                while let Some(nc) = chars.next() {
                    current.push(nc);
                    if nc == '\n' {
                        break;
                    }
                }
            }
            '/' if chars.peek() == Some(&'*') => {
                // Block comment: copy through the closing `*/`.
                current.push(c);
                current.push(chars.next().unwrap_or('*'));
                while let Some(bc) = chars.next() {
                    current.push(bc);
                    if bc == '*' && chars.peek() == Some(&'/') {
                        current.push('/');
                        chars.next();
                        break;
                    }
                }
            }
            ';' => flush(&mut out, &mut current),
            _ => current.push(c),
        }
    }
    flush(&mut out, &mut current);

    out
}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a not-yet-created path under the system temp dir, made unique
    /// with the process id and the current time in nanoseconds.
    fn tmp(prefix: &str) -> PathBuf {
        let path = std::env::temp_dir().join(format!(
            "rustio-mig-{prefix}-{}-{}",
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos()
        ));
        let _ = fs::remove_dir_all(&path);
        path
    }

    #[test]
    fn sanitize_lowercases_and_underscores() {
        assert_eq!(sanitize_name("Add Blog Table"), "add_blog_table");
        assert_eq!(sanitize_name("create-users-table"), "create_users_table");
        assert_eq!(sanitize_name("add spaces"), "add_spaces");
        assert_eq!(sanitize_name("CamelCase"), "camelcase");
    }

    #[test]
    fn sanitize_trims_outer_separators() {
        assert_eq!(sanitize_name("_add_"), "add");
        assert_eq!(sanitize_name("--blog--"), "blog");
    }

    #[test]
    fn sanitize_empty_returns_empty() {
        assert_eq!(sanitize_name(""), "");
        assert_eq!(sanitize_name(" "), "");
        assert_eq!(sanitize_name("!!!"), "");
    }

    #[test]
    fn next_number_starts_at_one() {
        assert_eq!(next_number(&[]), 1);
    }

    #[test]
    fn next_number_follows_highest() {
        let files = vec![
            PathBuf::from("migrations/0001_first.sql"),
            PathBuf::from("migrations/0003_third.sql"),
            PathBuf::from("migrations/0002_second.sql"),
        ];
        assert_eq!(next_number(&files), 4);
    }

    #[test]
    fn next_number_ignores_non_numeric_prefixes() {
        let files = vec![
            PathBuf::from("migrations/readme.sql"),
            PathBuf::from("migrations/0005_real.sql"),
        ];
        assert_eq!(next_number(&files), 6);
    }

    #[test]
    fn split_sql_handles_multiple_statements() {
        let sql = "CREATE TABLE a (id INT); CREATE TABLE b (id INT);";
        let stmts = split_sql(sql);
        assert_eq!(
            stmts,
            vec![
                String::from("CREATE TABLE a (id INT)"),
                String::from("CREATE TABLE b (id INT)"),
            ]
        );
    }

    #[test]
    fn split_sql_ignores_empty_trailing() {
        assert!(split_sql(";; ;").is_empty());
    }

    #[test]
    fn split_sql_preserves_semicolon_inside_string_literal() {
        assert_eq!(
            split_sql("INSERT INTO t VALUES ('a;b'); CREATE TABLE x (id INT);"),
            vec![
                String::from("INSERT INTO t VALUES ('a;b')"),
                String::from("CREATE TABLE x (id INT)"),
            ]
        );
    }

    #[test]
    fn split_sql_handles_escaped_single_quote() {
        assert_eq!(
            split_sql("INSERT VALUES ('it''s; fine');"),
            vec![String::from("INSERT VALUES ('it''s; fine')")]
        );
    }

    #[test]
    fn split_sql_skips_semicolons_inside_line_comment() {
        assert_eq!(
            split_sql("-- first; second\nCREATE TABLE t (id INT);"),
            vec![String::from("-- first; second\nCREATE TABLE t (id INT)")]
        );
    }

    #[test]
    fn split_sql_skips_semicolons_inside_block_comment() {
        assert_eq!(
            split_sql("/* a;b;c */ CREATE TABLE t (id INT);"),
            vec![String::from("/* a;b;c */ CREATE TABLE t (id INT)")]
        );
    }

    #[test]
    fn list_returns_sorted_sql_files_only() {
        let dir = tmp("list");
        // A directory named like a migration must not be picked up.
        fs::create_dir_all(dir.join("0003_subdir.sql")).unwrap();
        fs::write(dir.join("0002_b.sql"), "").unwrap();
        fs::write(dir.join("0001_a.sql"), "").unwrap();
        fs::write(dir.join("notes.txt"), "").unwrap();

        let names: Vec<String> = list(&dir)
            .unwrap()
            .iter()
            .filter_map(|p| p.file_name()?.to_str().map(String::from))
            .collect();
        assert_eq!(names, vec!["0001_a.sql", "0002_b.sql"]);

        fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn list_on_missing_dir_is_empty() {
        let dir = tmp("list-missing");
        assert!(list(&dir).unwrap().is_empty());
    }

    #[test]
    fn generate_creates_files_with_numbered_prefixes() {
        let dir = tmp("gen");
        let p1 = generate(&dir, "create users", "-- one").unwrap();
        let p2 = generate(&dir, "add index", "-- two").unwrap();
        assert_eq!(
            p1.file_name().unwrap().to_string_lossy(),
            "0001_create_users.sql"
        );
        assert_eq!(p2.file_name().unwrap().to_string_lossy(), "0002_add_index.sql");
        assert_eq!(fs::read_to_string(&p1).unwrap(), "-- one");
        assert_eq!(fs::read_to_string(&p2).unwrap(), "-- two");
        fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn generate_rejects_empty_name_after_sanitization() {
        let dir = tmp("gen-empty");
        assert!(matches!(
            generate(&dir, "!!!", ""),
            Err(Error::BadRequest(_))
        ));
        fs::remove_dir_all(&dir).ok();
    }

    #[tokio::test]
    async fn apply_creates_tracking_table_even_with_no_migrations() {
        let db = Db::memory().await.unwrap();
        let dir = tmp("apply-empty");
        fs::create_dir_all(&dir).unwrap();
        let applied = apply(&db, &dir).await.unwrap();
        assert!(applied.is_empty());
        let row = sqlx::query("SELECT COUNT(*) FROM rustio_migrations")
            .fetch_one(db.pool())
            .await
            .unwrap();
        let count: i64 = row.get(0);
        assert_eq!(count, 0);
        fs::remove_dir_all(&dir).ok();
    }

    #[tokio::test]
    async fn apply_runs_pending_and_is_idempotent() {
        let db = Db::memory().await.unwrap();
        let dir = tmp("apply-idem");
        fs::create_dir_all(&dir).unwrap();
        fs::write(dir.join("0001_create.sql"), "CREATE TABLE t (id INTEGER);").unwrap();
        fs::write(
            dir.join("0002_insert.sql"),
            "INSERT INTO t (id) VALUES (42);",
        )
        .unwrap();

        let first = apply(&db, &dir).await.unwrap();
        assert_eq!(first, vec!["0001_create.sql", "0002_insert.sql"]);

        let second = apply(&db, &dir).await.unwrap();
        assert!(second.is_empty());

        let row = sqlx::query("SELECT id FROM t")
            .fetch_one(db.pool())
            .await
            .unwrap();
        let id: i64 = row.get(0);
        assert_eq!(id, 42);

        fs::remove_dir_all(&dir).ok();
    }

    #[tokio::test]
    async fn apply_picks_up_new_migration_added_later() {
        let db = Db::memory().await.unwrap();
        let dir = tmp("apply-followup");
        fs::create_dir_all(&dir).unwrap();
        fs::write(
            dir.join("0001_first.sql"),
            "CREATE TABLE first (id INTEGER);",
        )
        .unwrap();
        apply(&db, &dir).await.unwrap();

        fs::write(
            dir.join("0002_second.sql"),
            "CREATE TABLE second (id INTEGER);",
        )
        .unwrap();
        let applied = apply(&db, &dir).await.unwrap();
        assert_eq!(applied, vec!["0002_second.sql"]);

        sqlx::query("INSERT INTO first (id) VALUES (1)")
            .execute(db.pool())
            .await
            .unwrap();
        sqlx::query("INSERT INTO second (id) VALUES (2)")
            .execute(db.pool())
            .await
            .unwrap();

        fs::remove_dir_all(&dir).ok();
    }

    #[tokio::test]
    async fn status_reports_applied_and_pending_separately() {
        let db = Db::memory().await.unwrap();
        let dir = tmp("status");
        fs::create_dir_all(&dir).unwrap();
        fs::write(dir.join("0001_a.sql"), "CREATE TABLE a (id INTEGER);").unwrap();
        fs::write(dir.join("0002_b.sql"), "CREATE TABLE b (id INTEGER);").unwrap();
        fs::write(dir.join("0003_c.sql"), "CREATE TABLE c (id INTEGER);").unwrap();

        let applied_now = apply(&db, &dir).await.unwrap();
        assert_eq!(applied_now.len(), 3);

        // Added after the first apply, so it must show up as pending.
        fs::write(dir.join("0004_d.sql"), "CREATE TABLE d (id INTEGER);").unwrap();

        let s = status(&db, &dir).await.unwrap();
        assert_eq!(s.applied.len(), 3);
        assert_eq!(
            s.applied
                .iter()
                .map(|r| r.filename.as_str())
                .collect::<Vec<_>>(),
            vec!["0001_a.sql", "0002_b.sql", "0003_c.sql"]
        );
        assert_eq!(s.pending, vec!["0004_d.sql"]);

        fs::remove_dir_all(&dir).ok();
    }

    #[tokio::test]
    async fn status_on_empty_project_returns_empty_both() {
        let db = Db::memory().await.unwrap();
        let dir = tmp("status-empty");
        fs::create_dir_all(&dir).unwrap();
        let s = status(&db, &dir).await.unwrap();
        assert!(s.applied.is_empty());
        assert!(s.pending.is_empty());
        fs::remove_dir_all(&dir).ok();
    }

    #[tokio::test]
    async fn failed_migration_rolls_back_and_is_not_marked_applied() {
        let db = Db::memory().await.unwrap();
        let dir = tmp("apply-failure");
        fs::create_dir_all(&dir).unwrap();
        fs::write(dir.join("0001_ok.sql"), "CREATE TABLE ok (id INTEGER);").unwrap();
        // Re-creating the same table fails, forcing the second migration to abort.
        fs::write(dir.join("0002_bad.sql"), "CREATE TABLE ok (id INTEGER);").unwrap();

        let result = apply(&db, &dir).await;
        assert!(result.is_err());

        let rows = sqlx::query("SELECT filename FROM rustio_migrations")
            .fetch_all(db.pool())
            .await
            .unwrap();
        let applied: Vec<String> = rows.iter().map(|r| r.get::<String, _>(0)).collect();
        assert_eq!(applied, vec!["0001_ok.sql"]);

        fs::remove_dir_all(&dir).ok();
    }
}