1use std::collections::HashSet;
9use std::fs;
10use std::path::{Path, PathBuf};
11
12use sqlx::Row as _;
13
14use crate::error::Error;
15use crate::orm::Db;
16
/// Name of the table that records which migration files have been applied.
const TRACKING_TABLE: &str = "rustio_migrations";
18
/// Options controlling how [`apply_with`] runs pending migrations.
#[derive(Clone, Copy, Debug, Default)]
pub struct ApplyOptions {
    /// When `true`, the name of each migration file being applied and every
    /// statement executed from it are echoed to stderr.
    pub verbose: bool,
}
25
26pub fn list(dir: &Path) -> Result<Vec<PathBuf>, Error> {
27 if !dir.exists() {
28 return Ok(Vec::new());
29 }
30 let entries = fs::read_dir(dir).map_err(|e| Error::Internal(e.to_string()))?;
31 let mut files: Vec<PathBuf> = entries
32 .filter_map(|e| e.ok())
33 .filter(|e| {
34 e.file_type().map(|t| t.is_file()).unwrap_or(false)
35 && e.path().extension().and_then(|s| s.to_str()) == Some("sql")
36 })
37 .map(|e| e.path())
38 .collect();
39 files.sort();
40 Ok(files)
41}
42
43pub fn generate(dir: &Path, name: &str, content: &str) -> Result<PathBuf, Error> {
44 let sanitized = sanitize_name(name);
45 if sanitized.is_empty() {
46 return Err(Error::BadRequest(
47 "migration name cannot be empty".to_string(),
48 ));
49 }
50 fs::create_dir_all(dir).map_err(|e| Error::Internal(e.to_string()))?;
51 let existing = list(dir)?;
52 let next = next_number(&existing);
53 let filename = format!("{:04}_{}.sql", next, sanitized);
54 let path = dir.join(filename);
55 fs::write(&path, content).map_err(|e| Error::Internal(e.to_string()))?;
56 Ok(path)
57}
58
/// One row read back from the migration tracking table.
#[derive(Clone, Debug)]
pub struct MigrationRecord {
    /// File name (not the full path) of the applied migration.
    pub filename: String,
    /// The row's `applied_at` column, read as text.
    pub applied_at: String,
}

/// Which migrations have been applied and which are still waiting to run.
#[derive(Debug)]
pub struct Status {
    /// Rows from the tracking table, ordered by filename.
    pub applied: Vec<MigrationRecord>,
    /// File names of `.sql` migrations on disk that have no tracking row.
    pub pending: Vec<String>,
}
70
71pub async fn applied(db: &Db) -> Result<Vec<MigrationRecord>, Error> {
72 ensure_tracking_table(db).await?;
73 let rows = sqlx::query(&format!(
74 "SELECT filename, applied_at FROM {TRACKING_TABLE} ORDER BY filename"
75 ))
76 .fetch_all(db.pool())
77 .await?;
78 Ok(rows
79 .iter()
80 .map(|r| MigrationRecord {
81 filename: r.get(0),
82 applied_at: r.get(1),
83 })
84 .collect())
85}
86
87pub async fn status(db: &Db, dir: &Path) -> Result<Status, Error> {
88 let applied_records = applied(db).await?;
89 let applied_names: HashSet<String> =
90 applied_records.iter().map(|r| r.filename.clone()).collect();
91 let files = list(dir)?;
92 let pending: Vec<String> = files
93 .iter()
94 .filter_map(|p| p.file_name().and_then(|n| n.to_str()).map(String::from))
95 .filter(|n| !applied_names.contains(n))
96 .collect();
97 Ok(Status {
98 applied: applied_records,
99 pending,
100 })
101}
102
103pub async fn apply(db: &Db, dir: &Path) -> Result<Vec<String>, Error> {
104 apply_with(db, dir, ApplyOptions::default()).await
105}
106
107pub async fn apply_with(db: &Db, dir: &Path, opts: ApplyOptions) -> Result<Vec<String>, Error> {
108 crate::auth::ensure_core_tables(db).await?;
114 ensure_tracking_table(db).await?;
115
116 let rows = sqlx::query(&format!("SELECT filename FROM {TRACKING_TABLE}"))
117 .fetch_all(db.pool())
118 .await?;
119 let already_applied: HashSet<String> = rows.iter().map(|r| r.get::<String, _>(0)).collect();
120
121 let files = list(dir)?;
122 let mut newly_applied = Vec::new();
123
124 for path in files {
125 let filename = match path.file_name().and_then(|n| n.to_str()) {
126 Some(n) => n.to_string(),
127 None => continue,
128 };
129 if already_applied.contains(&filename) {
130 continue;
131 }
132
133 let sql = fs::read_to_string(&path)
134 .map_err(|e| Error::Internal(format!("reading {filename}: {e}")))?;
135
136 if opts.verbose {
137 eprintln!("-- applying {filename}");
138 }
139
140 let mut tx = db.pool().begin().await?;
141 for stmt in split_sql(&sql) {
142 if opts.verbose {
143 eprintln!(" {}", stmt);
144 }
145 sqlx::query(&stmt)
146 .execute(&mut *tx)
147 .await
148 .map_err(|e| Error::Internal(format!("migration {filename} failed: {e}")))?;
149 }
150 sqlx::query(&format!(
151 "INSERT INTO {TRACKING_TABLE} (filename) VALUES (?)"
152 ))
153 .bind(&filename)
154 .execute(&mut *tx)
155 .await?;
156 tx.commit().await?;
157
158 newly_applied.push(filename);
159 }
160
161 Ok(newly_applied)
162}
163
/// Creates the migration tracking table if it does not already exist.
///
/// Each row holds a migration's file name and an `applied_at` timestamp:
/// - `filename` is the primary key, so a file can be recorded at most once.
/// - `applied_at` defaults to `datetime('now')` when the row is inserted
///   (a SQLite date function; assumes the backend is SQLite).
///
/// If a table with this name already exists, it is left untouched even if its
/// columns differ from the ones declared here.
async fn ensure_tracking_table(db: &Db) -> Result<(), Error> {
    db.execute(&format!(
        "CREATE TABLE IF NOT EXISTS {TRACKING_TABLE} (
            filename TEXT PRIMARY KEY,
            applied_at TEXT NOT NULL DEFAULT (datetime('now'))
        )"
    ))
    .await
}
173
/// Returns one more than the largest numeric prefix found among `files`.
///
/// A file's prefix is the part of its file name before the first `_`. Names
/// without `_`, names that are not valid UTF-8, and prefixes that do not parse
/// as `u32` are ignored. With no numbered files the result is `1`.
fn next_number(files: &[PathBuf]) -> u32 {
    let mut highest: Option<u32> = None;
    for path in files {
        let name = match path.file_name().and_then(|n| n.to_str()) {
            Some(name) => name,
            None => continue,
        };
        let number = match name.split_once('_') {
            Some((prefix, _)) => prefix.parse::<u32>().ok(),
            None => None,
        };
        if let Some(n) = number {
            highest = Some(highest.map_or(n, |h| h.max(n)));
        }
    }
    match highest {
        Some(h) => h + 1,
        None => 1,
    }
}
186
/// Turns a free-form migration name into a filename-safe slug.
///
/// ASCII letters and digits are kept and lowercased. Every maximal run of any
/// other characters (including non-ASCII ones) becomes a single `_`, and the
/// result never starts or ends with `_`. Returns an empty string when `name`
/// has no ASCII alphanumerics.
fn sanitize_name(name: &str) -> String {
    name.split(|ch: char| !ch.is_ascii_alphanumeric())
        .filter(|piece| !piece.is_empty())
        .map(str::to_ascii_lowercase)
        .collect::<Vec<String>>()
        .join("_")
}
203
/// Position of [`split_sql`] within the current statement, at token level.
///
/// Modeled on SQLite's `sqlite3_complete()`. A `;` normally ends a statement.
/// Inside `CREATE [TEMP|TEMPORARY] TRIGGER`, however, the body's own
/// semicolons do not: the statement ends only at a `;` that directly follows
/// `; END`. A `CASE ... END;` inside the body is therefore not mistaken for
/// the end of the trigger.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum SplitState {
    /// No token seen yet in the current statement.
    Start,
    /// Seen `CREATE`, possibly followed by `TEMP` / `TEMPORARY`.
    Create,
    /// An ordinary statement; the next `;` ends it.
    Normal,
    /// Inside a `CREATE TRIGGER` statement.
    Trigger,
    /// Inside a trigger, directly after a `;`.
    TriggerSemi,
    /// Inside a trigger, directly after `; END`; the next `;` ends it.
    TriggerEnd,
}

impl SplitState {
    /// Transition on a bare word (keyword, identifier or number).
    fn on_word(self, word: &str) -> SplitState {
        let is = |kw: &str| word.eq_ignore_ascii_case(kw);
        match self {
            SplitState::Start if is("CREATE") => SplitState::Create,
            SplitState::Create if is("TEMP") || is("TEMPORARY") => SplitState::Create,
            SplitState::Create if is("TRIGGER") => SplitState::Trigger,
            SplitState::TriggerSemi if is("END") => SplitState::TriggerEnd,
            SplitState::Start | SplitState::Create | SplitState::Normal => SplitState::Normal,
            SplitState::Trigger | SplitState::TriggerSemi | SplitState::TriggerEnd => {
                SplitState::Trigger
            }
        }
    }

    /// Transition on any other token: punctuation or a quoted literal.
    fn on_other(self) -> SplitState {
        match self {
            SplitState::Start | SplitState::Create | SplitState::Normal => SplitState::Normal,
            SplitState::Trigger | SplitState::TriggerSemi | SplitState::TriggerEnd => {
                SplitState::Trigger
            }
        }
    }

    /// Transition on `;`. `None` means this `;` terminates the statement.
    fn on_semicolon(self) -> Option<SplitState> {
        match self {
            SplitState::Trigger | SplitState::TriggerSemi => Some(SplitState::TriggerSemi),
            SplitState::Start
            | SplitState::Create
            | SplitState::Normal
            | SplitState::TriggerEnd => None,
        }
    }
}

/// Splits a migration script into individual statements.
///
/// Statements are separated by `;`. A `;` does NOT split when it appears:
/// - inside a quoted token (`'string'`, `"identifier"` or `` `identifier` ``),
///   where a doubled quote character is an escaped quote;
/// - inside a `-- line comment` or a `/* block comment */`;
/// - inside the body of `CREATE [TEMP|TEMPORARY] TRIGGER ... BEGIN ... END`.
///
/// Comments stay part of the statement text. Each statement is trimmed of
/// surrounding whitespace and loses its terminating `;`, and empty statements
/// are dropped. An unterminated quote or block comment runs to the end of the
/// input.
fn split_sql(sql: &str) -> Vec<String> {
    let mut out = Vec::new();
    let mut current = String::new();
    let mut word = String::new();
    let mut state = SplitState::Start;
    let mut chars = sql.chars().peekable();

    while let Some(c) = chars.next() {
        if c.is_alphanumeric() || c == '_' {
            word.push(c);
            current.push(c);
            continue;
        }
        // Any non-word character ends the word being accumulated.
        if !word.is_empty() {
            state = state.on_word(&word);
            word.clear();
        }

        match c {
            '\'' | '"' | '`' => {
                state = state.on_other();
                current.push(c);
                while let Some(q) = chars.next() {
                    current.push(q);
                    if q == c {
                        // A doubled quote is an escaped quote, not the end.
                        match chars.next_if_eq(&c) {
                            Some(escaped) => current.push(escaped),
                            None => break,
                        }
                    }
                }
            }
            '-' if chars.peek() == Some(&'-') => {
                current.push(c);
                for nc in chars.by_ref() {
                    current.push(nc);
                    if nc == '\n' {
                        break;
                    }
                }
            }
            '/' if chars.peek() == Some(&'*') => {
                current.push(c);
                if let Some(star) = chars.next() {
                    current.push(star);
                }
                while let Some(nc) = chars.next() {
                    current.push(nc);
                    if nc == '*' {
                        if let Some(slash) = chars.next_if_eq(&'/') {
                            current.push(slash);
                            break;
                        }
                    }
                }
            }
            ';' => match state.on_semicolon() {
                Some(next) => {
                    // A `;` inside a trigger body belongs to the statement.
                    state = next;
                    current.push(c);
                }
                None => {
                    let trimmed = current.trim();
                    if !trimmed.is_empty() {
                        out.push(trimmed.to_string());
                    }
                    current.clear();
                    state = SplitState::Start;
                }
            },
            _ => {
                if !c.is_whitespace() {
                    state = state.on_other();
                }
                current.push(c);
            }
        }
    }

    let trimmed = current.trim();
    if !trimmed.is_empty() {
        out.push(trimmed.to_string());
    }
    out
}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a fresh path under the system temp dir, unique per process and
    /// call, with any leftover directory at that path removed.
    fn tmp(prefix: &str) -> PathBuf {
        let path = std::env::temp_dir().join(format!(
            "rustio-mig-{prefix}-{}-{}",
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos()
        ));
        let _ = fs::remove_dir_all(&path);
        path
    }

    #[test]
    fn sanitize_lowercases_and_underscores() {
        assert_eq!(sanitize_name("Add Blog Table"), "add_blog_table");
        assert_eq!(sanitize_name("create-users-table"), "create_users_table");
        assert_eq!(sanitize_name("add spaces"), "add_spaces");
        assert_eq!(sanitize_name("CamelCase"), "camelcase");
    }

    #[test]
    fn sanitize_trims_outer_separators() {
        assert_eq!(sanitize_name("_add_"), "add");
        assert_eq!(sanitize_name("--blog--"), "blog");
    }

    #[test]
    fn sanitize_empty_returns_empty() {
        assert_eq!(sanitize_name(""), "");
        assert_eq!(sanitize_name(" "), "");
        assert_eq!(sanitize_name("!!!"), "");
    }

    #[test]
    fn next_number_starts_at_one() {
        assert_eq!(next_number(&[]), 1);
    }

    #[test]
    fn next_number_follows_highest() {
        let files = vec![
            PathBuf::from("migrations/0001_first.sql"),
            PathBuf::from("migrations/0003_third.sql"),
            PathBuf::from("migrations/0002_second.sql"),
        ];
        assert_eq!(next_number(&files), 4);
    }

    #[test]
    fn next_number_ignores_non_numeric_prefixes() {
        let files = vec![
            PathBuf::from("migrations/readme.sql"),
            PathBuf::from("migrations/0005_real.sql"),
        ];
        assert_eq!(next_number(&files), 6);
    }

    #[test]
    fn split_sql_handles_multiple_statements() {
        let sql = "CREATE TABLE a (id INT); CREATE TABLE b (id INT);";
        let stmts = split_sql(sql);
        assert_eq!(
            stmts,
            vec![
                String::from("CREATE TABLE a (id INT)"),
                String::from("CREATE TABLE b (id INT)"),
            ]
        );
    }

    #[test]
    fn split_sql_ignores_empty_trailing() {
        assert!(split_sql(";; ;").is_empty());
    }

    #[test]
    fn split_sql_preserves_semicolon_inside_string_literal() {
        assert_eq!(
            split_sql("INSERT INTO t VALUES ('a;b'); CREATE TABLE x (id INT);"),
            vec![
                String::from("INSERT INTO t VALUES ('a;b')"),
                String::from("CREATE TABLE x (id INT)"),
            ]
        );
    }

    #[test]
    fn split_sql_handles_escaped_single_quote() {
        assert_eq!(
            split_sql("INSERT VALUES ('it''s; fine');"),
            vec![String::from("INSERT VALUES ('it''s; fine')")]
        );
    }

    #[test]
    fn split_sql_skips_semicolons_inside_line_comment() {
        assert_eq!(
            split_sql("-- first; second\nCREATE TABLE t (id INT);"),
            vec![String::from("-- first; second\nCREATE TABLE t (id INT)")]
        );
    }

    #[test]
    fn split_sql_skips_semicolons_inside_block_comment() {
        assert_eq!(
            split_sql("/* a;b;c */ CREATE TABLE t (id INT);"),
            vec![String::from("/* a;b;c */ CREATE TABLE t (id INT)")]
        );
    }

    #[test]
    fn list_of_missing_dir_is_empty() {
        let dir = tmp("list-missing");
        assert!(list(&dir).unwrap().is_empty());
    }

    #[test]
    fn list_returns_only_sql_files_in_order() {
        let dir = tmp("list");
        fs::create_dir_all(dir.join("0003_subdir.sql")).unwrap();
        fs::write(dir.join("0002_b.sql"), "").unwrap();
        fs::write(dir.join("0001_a.sql"), "").unwrap();
        fs::write(dir.join("notes.txt"), "").unwrap();

        let names: Vec<String> = list(&dir)
            .unwrap()
            .iter()
            .map(|p| p.file_name().unwrap().to_string_lossy().into_owned())
            .collect();
        assert_eq!(names, vec!["0001_a.sql", "0002_b.sql"]);

        fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn generate_creates_files_with_numbered_prefixes() {
        let dir = tmp("gen");
        let p1 = generate(&dir, "create users", "-- one").unwrap();
        let p2 = generate(&dir, "add index", "-- two").unwrap();
        assert!(p1
            .file_name()
            .unwrap()
            .to_string_lossy()
            .starts_with("0001_create_users"));
        assert!(p2
            .file_name()
            .unwrap()
            .to_string_lossy()
            .starts_with("0002_add_index"));
        assert_eq!(fs::read_to_string(&p1).unwrap(), "-- one");
        fs::remove_dir_all(&dir).ok();
    }

    #[test]
    fn generate_rejects_empty_name_after_sanitization() {
        let dir = tmp("gen-empty");
        assert!(matches!(
            generate(&dir, "!!!", ""),
            Err(Error::BadRequest(_))
        ));
        fs::remove_dir_all(&dir).ok();
    }

    #[tokio::test]
    async fn apply_creates_tracking_table_even_with_no_migrations() {
        let db = Db::memory().await.unwrap();
        let dir = tmp("apply-empty");
        fs::create_dir_all(&dir).unwrap();
        let applied = apply(&db, &dir).await.unwrap();
        assert!(applied.is_empty());
        let row = sqlx::query("SELECT COUNT(*) FROM rustio_migrations")
            .fetch_one(db.pool())
            .await
            .unwrap();
        let count: i64 = row.get(0);
        assert_eq!(count, 0);
        fs::remove_dir_all(&dir).ok();
    }

    #[tokio::test]
    async fn apply_runs_pending_and_is_idempotent() {
        let db = Db::memory().await.unwrap();
        let dir = tmp("apply-idem");
        fs::create_dir_all(&dir).unwrap();
        fs::write(dir.join("0001_create.sql"), "CREATE TABLE t (id INTEGER);").unwrap();
        fs::write(
            dir.join("0002_insert.sql"),
            "INSERT INTO t (id) VALUES (42);",
        )
        .unwrap();

        let first = apply(&db, &dir).await.unwrap();
        assert_eq!(first, vec!["0001_create.sql", "0002_insert.sql"]);

        let second = apply(&db, &dir).await.unwrap();
        assert!(second.is_empty());

        let row = sqlx::query("SELECT id FROM t")
            .fetch_one(db.pool())
            .await
            .unwrap();
        let id: i64 = row.get(0);
        assert_eq!(id, 42);

        fs::remove_dir_all(&dir).ok();
    }

    #[tokio::test]
    async fn apply_picks_up_new_migration_added_later() {
        let db = Db::memory().await.unwrap();
        let dir = tmp("apply-followup");
        fs::create_dir_all(&dir).unwrap();
        fs::write(
            dir.join("0001_first.sql"),
            "CREATE TABLE first (id INTEGER);",
        )
        .unwrap();
        apply(&db, &dir).await.unwrap();

        fs::write(
            dir.join("0002_second.sql"),
            "CREATE TABLE second (id INTEGER);",
        )
        .unwrap();
        let applied = apply(&db, &dir).await.unwrap();
        assert_eq!(applied, vec!["0002_second.sql"]);

        sqlx::query("INSERT INTO first (id) VALUES (1)")
            .execute(db.pool())
            .await
            .unwrap();
        sqlx::query("INSERT INTO second (id) VALUES (2)")
            .execute(db.pool())
            .await
            .unwrap();

        fs::remove_dir_all(&dir).ok();
    }

    #[tokio::test]
    async fn status_reports_applied_and_pending_separately() {
        let db = Db::memory().await.unwrap();
        let dir = tmp("status");
        fs::create_dir_all(&dir).unwrap();
        fs::write(dir.join("0001_a.sql"), "CREATE TABLE a (id INTEGER);").unwrap();
        fs::write(dir.join("0002_b.sql"), "CREATE TABLE b (id INTEGER);").unwrap();
        fs::write(dir.join("0003_c.sql"), "CREATE TABLE c (id INTEGER);").unwrap();

        let applied_now = apply(&db, &dir).await.unwrap();
        assert_eq!(applied_now.len(), 3);

        fs::write(dir.join("0004_d.sql"), "CREATE TABLE d (id INTEGER);").unwrap();

        let s = status(&db, &dir).await.unwrap();
        assert_eq!(s.applied.len(), 3);
        assert_eq!(
            s.applied
                .iter()
                .map(|r| r.filename.as_str())
                .collect::<Vec<_>>(),
            vec!["0001_a.sql", "0002_b.sql", "0003_c.sql"]
        );
        assert_eq!(s.pending, vec!["0004_d.sql"]);

        fs::remove_dir_all(&dir).ok();
    }

    #[tokio::test]
    async fn status_on_empty_project_returns_empty_both() {
        let db = Db::memory().await.unwrap();
        let dir = tmp("status-empty");
        fs::create_dir_all(&dir).unwrap();
        let s = status(&db, &dir).await.unwrap();
        assert!(s.applied.is_empty());
        assert!(s.pending.is_empty());
        fs::remove_dir_all(&dir).ok();
    }

    #[tokio::test]
    async fn failed_migration_rolls_back_and_is_not_marked_applied() {
        let db = Db::memory().await.unwrap();
        let dir = tmp("apply-failure");
        fs::create_dir_all(&dir).unwrap();
        fs::write(dir.join("0001_ok.sql"), "CREATE TABLE ok (id INTEGER);").unwrap();
        // Re-creating the same table makes the second migration fail.
        fs::write(dir.join("0002_bad.sql"), "CREATE TABLE ok (id INTEGER);").unwrap();

        let result = apply(&db, &dir).await;
        assert!(result.is_err());

        let rows = sqlx::query("SELECT filename FROM rustio_migrations")
            .fetch_all(db.pool())
            .await
            .unwrap();
        let applied: Vec<String> = rows.iter().map(|r| r.get::<String, _>(0)).collect();
        assert_eq!(applied, vec!["0001_ok.sql"]);

        fs::remove_dir_all(&dir).ok();
    }
}
556}