1#![allow(
2 clippy::struct_excessive_bools,
3 clippy::too_many_lines,
4 unused_imports,
5 dead_code,
6 unused_variables
7)]
8use crate::{db, prelude::*, DatabaseType, DEFAULT_MIGRATIONS_TABLE};
9use clap::Parser;
10use comfy_table::{Cell, CellAlignment, ContentArrangement, Table};
11use filetime::FileTime;
12use regex::Regex;
13use sqlx::{ConnectOptions, Database, Executor};
14use std::{fs, io, path::Path, process, str::FromStr, time::Duration};
15use time::{format_description, OffsetDateTime};
16use tracing_subscriber::{
17 fmt::format::FmtSpan, prelude::__tracing_subscriber_SubscriberExt, util::SubscriberInitExt,
18 EnvFilter,
19};
20
21#[derive(Debug, clap::Parser)]
23pub struct Migrate {
24 #[clap(long, global(true))]
26 pub no_colors: bool,
27 #[clap(long, global(true))]
29 pub verbose: bool,
30 #[clap(long = "force", global(true))]
32 pub force: bool,
33 #[clap(long, alias = "no-verify-checksum", global(true))]
35 pub no_verify_checksums: bool,
36 #[clap(long, alias = "no-verify-name", global(true))]
38 pub no_verify_names: bool,
39 #[clap(long, global(true))]
41 pub no_env_file: bool,
42 #[clap(long, global(true))]
44 pub log_statements: bool,
45 #[clap(long, visible_alias = "db-url", global(true))]
47 pub database_url: Option<String>,
48 #[clap(long, default_value = DEFAULT_MIGRATIONS_TABLE, global(true))]
50 pub migrations_table: String,
51 #[clap(subcommand)]
52 pub operation: Operation,
53}
54
55#[derive(Debug, clap::Subcommand)]
57pub enum Operation {
58 #[clap(visible_aliases = &["up", "mig"])]
62 Migrate {
63 #[clap(long, conflicts_with = "version")]
66 name: Option<String>,
67
68 #[clap(long, conflicts_with = "name")]
71 version: Option<u64>,
72 },
73 #[clap(visible_aliases = &["down", "rev"])]
77 Revert {
78 #[clap(long, conflicts_with = "version")]
81 name: Option<String>,
82
83 #[clap(long, conflicts_with = "name")]
86 version: Option<u64>,
87 },
88 #[clap(visible_aliases = &["override"])]
93 Set {
94 #[clap(long, conflicts_with = "version", required_unless_present("version"))]
96 name: Option<String>,
97 #[clap(long, conflicts_with = "name", required_unless_present("name"))]
99 version: Option<u64>,
100 },
101 #[clap(visible_aliases = &["verify", "validate"])]
103 Check {},
104 #[clap(visible_aliases = &["list", "ls", "get"])]
106 Status {},
107 #[cfg(debug_assertions)]
111 #[clap(visible_aliases = &["new"])]
112 Add {
113 #[clap(long)]
115 sql: bool,
116 #[clap(long, short = 'r', visible_aliases = &["revert", "revertible"])]
118 reversible: bool,
119 #[clap(
123 long = "database",
124 visible_aliases = &["db"],
125 aliases = &["type"],
126 default_value = "postgres",
127 value_enum
128 )]
129 ty: DatabaseType,
130 name: String,
134 },
135}
136
137pub fn run<Db>(
151 migrations_path: impl AsRef<Path>,
152 migrations: impl IntoIterator<Item = Migration<Db>>,
153) where
154 Db: Database,
155 Db::Connection: db::Migrations,
156 for<'a> &'a mut Db::Connection: Executor<'a>,
157{
158 run_parsed(Migrate::parse(), migrations_path, migrations);
159}
160
161#[allow(clippy::missing_panics_doc)]
163pub fn run_parsed<Db>(
164 migrate: Migrate,
165 migrations_path: impl AsRef<Path>,
166 migrations: impl IntoIterator<Item = Migration<Db>>,
167) where
168 Db: Database,
169 Db::Connection: db::Migrations,
170 for<'a> &'a mut Db::Connection: Executor<'a>,
171{
172 setup_logging(&migrate);
173
174 if !migrate.no_env_file {
175 if let Ok(cwd) = std::env::current_dir() {
176 let env_path = cwd.join(".env");
177 if env_path.is_file() {
178 tracing::info!(path = ?env_path, ".env file found");
179 if let Err(err) = dotenvy::from_path(&env_path) {
180 tracing::warn!(path = ?env_path, error = %err, "failed to load .env file");
181 }
182 }
183 }
184 }
185
186 let migrations = migrations.into_iter().collect::<Vec<_>>();
187
188 tokio::runtime::Builder::new_current_thread()
189 .enable_all()
190 .build()
191 .unwrap()
192 .block_on(execute(migrate, migrations_path.as_ref(), migrations));
193}
194
195async fn execute<Db>(migrate: Migrate, migrations_path: &Path, migrations: Vec<Migration<Db>>)
196where
197 Db: Database,
198 Db::Connection: db::Migrations,
199 for<'a> &'a mut Db::Connection: Executor<'a>,
200{
201 match &migrate.operation {
202 Operation::Migrate { name, version } => {
203 let migrator = setup_migrator(&migrate, migrations).await;
204 do_migrate(&migrate, migrator, name.as_deref(), *version).await;
205 }
206 Operation::Revert { name, version } => {
207 let migrator = setup_migrator(&migrate, migrations).await;
208 revert(&migrate, migrator, name.as_deref(), *version).await;
209 }
210 Operation::Set { name, version } => {
211 let migrator = setup_migrator(&migrate, migrations).await;
212 force(&migrate, migrator, name.as_deref(), *version).await;
213 }
214 Operation::Check {} => {
215 let migrator = setup_migrator(&migrate, migrations).await;
216 check(&migrate, migrator).await;
217 }
218 Operation::Status {} => {
219 let migrator = setup_migrator(&migrate, migrations).await;
220 log_status(&migrate, migrator).await;
221 }
222 #[cfg(debug_assertions)]
223 Operation::Add {
224 sql,
225 reversible,
226 name,
227 ty,
228 } => add(&migrate, migrations_path, *sql, *reversible, name, *ty),
229 }
230}
231
232async fn check<Db>(_migrate: &Migrate, migrator: Migrator<Db>)
233where
234 Db: Database,
235 Db::Connection: db::Migrations,
236 for<'a> &'a mut Db::Connection: Executor<'a>,
237{
238 match migrator.verify().await {
239 Ok(_) => {
240 tracing::info!("No issues found");
241 }
242 Err(err) => {
243 tracing::error!(error = %err, "error verifying migrations");
244 process::exit(1);
245 }
246 }
247}
248
249#[cfg(debug_assertions)]
250fn add(
251 _migrate: &Migrate,
252 migrations_path: &Path,
253 sql: bool,
254 reversible: bool,
255 name: &str,
256 ty: DatabaseType,
257) {
258 let now = OffsetDateTime::now_utc();
259
260 let now_formatted = now
261 .format(&format_description::parse("[year][month][day][hour][minute][second]").unwrap())
262 .unwrap();
263
264 if !migrations_path.is_dir() {
265 tracing::error!("migrations path must be a directory");
266 process::exit(1);
267 }
268
269 let re = Regex::new("[A-Za-z_][A-Za-z_0-9]*").unwrap();
270
271 if !re.is_match(name) {
272 tracing::error!(name, "invalid migration name");
273 process::exit(1);
274 }
275
276 if sql {
277 let up_filename = format!("{}_{}.migrate.sql", &now_formatted, name);
278
279 if let Err(error) = fs::write(
280 migrations_path.join(&up_filename),
281 format!(
282 r#"-- Migration SQL for {name}
283"#,
284 ),
285 ) {
286 tracing::error!(error = %error, path = ?migrations_path.join(&up_filename), "failed to write file");
287 process::exit(1);
288 }
289
290 if reversible {
291 let down_filename = format!("{}_{}.revert.sql", &now_formatted, name);
292 if let Err(error) = fs::write(
293 migrations_path.join(&down_filename),
294 format!(
295 r#"-- Revert SQL for {name}
296"#,
297 ),
298 ) {
299 tracing::error!(error = %error, path = ?migrations_path.join(&down_filename), "failed to write file");
300 process::exit(1);
301 }
302 }
303
304 tracing::info!(name, "added migration");
305 } else {
306 let up_filename = format!("{}_{}.migrate.rs", &now_formatted, name);
307
308 let sqlx_type = ty.sqlx_type();
309
310 if let Err(error) = fs::write(
311 migrations_path.join(&up_filename),
312 format!(
313 r#"use sqlx::{sqlx_type};
314use sqlx_migrate::prelude::*;
315
316/// Executes migration `{name}` in the given migration context.
317//
318// Do not modify the function name.
319// Do not modify the signature with the exception of the SQLx database type.
320pub async fn {name}(ctx: &mut MigrationContext<{sqlx_type}>) -> Result<(), MigrationError> {{
321 // write your migration operations here
322 todo!()
323}}
324"#,
325 ),
326 ) {
327 tracing::error!(error = %error, path = ?migrations_path.join(&up_filename), "failed to write file");
328 process::exit(1);
329 }
330
331 if reversible {
332 let down_filename = format!("{}_{}.revert.rs", &now_formatted, name);
333
334 if let Err(error) = fs::write(
335 migrations_path.join(&down_filename),
336 format!(
337 r#"use sqlx::{sqlx_type};
338use sqlx_migrate::prelude::*;
339
340/// Reverts migration `{name}` in the given migration context.
341//
342// Do not modify the function name.
343// Do not modify the signature with the exception of the SQLx database type.
344pub async fn revert_{name}(ctx: &mut MigrationContext<{sqlx_type}>) -> Result<(), MigrationError> {{
345 // write your revert operations here
346 todo!()
347}}
348"#,
349 ),
350 ) {
351 tracing::error!(error = %error, path = ?migrations_path.join(&down_filename), "failed to write file");
352 process::exit(1);
353 }
354 }
355 }
356
357 if let Err(err) = filetime::set_file_mtime(migrations_path, FileTime::now()) {
358 tracing::debug!(error = %err, "error updating the migrations directory");
359 }
360}
361
362async fn do_migrate<Db>(
363 _migrate: &Migrate,
364 migrator: Migrator<Db>,
365 name: Option<&str>,
366 version: Option<u64>,
367) where
368 Db: Database,
369 Db::Connection: db::Migrations,
370 for<'a> &'a mut Db::Connection: Executor<'a>,
371{
372 let version = match version {
373 Some(v) => Some(v),
374 None => match name {
375 Some(name) => {
376 if let Some((idx, _)) = migrator
377 .local_migrations()
378 .iter()
379 .enumerate()
380 .find(|mig| mig.1.name() == name)
381 {
382 Some(idx as u64 + 1)
383 } else {
384 tracing::error!(name = name, "migration not found");
385 process::exit(1);
386 }
387 }
388 None => None,
389 },
390 };
391
392 match version {
393 Some(version) => match migrator.migrate(version).await {
394 Ok(s) => print_summary(&s),
395 Err(error) => {
396 tracing::error!(error = %error, "error applying migrations");
397 process::exit(1);
398 }
399 },
400 None => match migrator.migrate_all().await {
401 Ok(s) => print_summary(&s),
402 Err(error) => {
403 tracing::error!(error = %error, "error applying migrations");
404 process::exit(1);
405 }
406 },
407 }
408}
409
410async fn revert<Db>(
411 migrate: &Migrate,
412 migrator: Migrator<Db>,
413 name: Option<&str>,
414 version: Option<u64>,
415) where
416 Db: Database,
417 Db::Connection: db::Migrations,
418 for<'a> &'a mut Db::Connection: Executor<'a>,
419{
420 if !migrate.force {
421 tracing::error!("the `--force` flag is required for this operation");
422 process::exit(1);
423 }
424
425 let version = match version {
426 Some(v) => Some(v),
427 None => match name {
428 Some(name) => {
429 if let Some((idx, _)) = migrator
430 .local_migrations()
431 .iter()
432 .enumerate()
433 .find(|mig| mig.1.name() == name)
434 {
435 Some(idx as u64 + 1)
436 } else {
437 tracing::error!(name = name, "migration not found");
438 process::exit(1);
439 }
440 }
441 None => None,
442 },
443 };
444
445 match version {
446 Some(version) => match migrator.revert(version).await {
447 Ok(s) => print_summary(&s),
448 Err(error) => {
449 tracing::error!(error = %error, "error reverting migrations");
450 process::exit(1);
451 }
452 },
453 None => match migrator.revert_all().await {
454 Ok(s) => print_summary(&s),
455 Err(error) => {
456 tracing::error!(error = %error, "error reverting migrations");
457 process::exit(1);
458 }
459 },
460 }
461}
462
463async fn force<Db>(
464 migrate: &Migrate,
465 migrator: Migrator<Db>,
466 name: Option<&str>,
467 version: Option<u64>,
468) where
469 Db: Database,
470 Db::Connection: db::Migrations,
471 for<'a> &'a mut Db::Connection: Executor<'a>,
472{
473 if !migrate.force {
474 tracing::error!("the `--do-as-i-say` or `--force` flag is required for this operation");
475 process::exit(1);
476 }
477
478 let version = match version {
479 Some(v) => v,
480 None => {
481 if let Some((idx, _)) = migrator
482 .local_migrations()
483 .iter()
484 .enumerate()
485 .find(|mig| mig.1.name() == name.unwrap())
486 {
487 idx as u64 + 1
488 } else {
489 tracing::error!(name = name.unwrap(), "migration not found");
490 process::exit(1);
491 }
492 }
493 };
494
495 match migrator.force_version(version).await {
496 Ok(s) => print_summary(&s),
497 Err(error) => {
498 tracing::error!(error = %error, "error updating migrations");
499 process::exit(1);
500 }
501 }
502}
503
504async fn log_status<Db>(_migrate: &Migrate, migrator: Migrator<Db>)
505where
506 Db: Database,
507 Db::Connection: db::Migrations,
508 for<'a> &'a mut Db::Connection: Executor<'a>,
509{
510 fn mig_ok(status: &MigrationStatus) -> bool {
511 if status.missing_local {
512 return false;
513 }
514
515 match &status.applied {
516 Some(applied) => {
517 status.checksum_ok
518 && status.name == applied.name
519 && status.version == applied.version
520 }
521 None => true,
522 }
523 }
524
525 let status = match migrator.status().await {
526 Ok(s) => s,
527 Err(error) => {
528 tracing::error!(error = %error, "error retrieving migration status");
529 process::exit(1);
530 }
531 };
532
533 let all_valid = status.iter().all(mig_ok);
534
535 let mut table = Table::new();
536
537 table
538 .set_content_arrangement(ContentArrangement::Dynamic)
539 .set_header(Vec::from([
540 Cell::new("Version").set_alignment(CellAlignment::Center),
541 Cell::new("Name").set_alignment(CellAlignment::Center),
542 Cell::new("Applied").set_alignment(CellAlignment::Center),
543 Cell::new("Valid").set_alignment(CellAlignment::Center),
544 Cell::new("Revertible").set_alignment(CellAlignment::Center),
545 ]));
546
547 for mig in status {
548 let ok = mig_ok(&mig);
549
550 table.add_row(Vec::from([
551 Cell::new(mig.version.to_string().as_str()).set_alignment(CellAlignment::Center),
552 Cell::new(&mig.name).set_alignment(CellAlignment::Center),
553 Cell::new(if mig.applied.is_some() { "x" } else { "" })
554 .set_alignment(CellAlignment::Center),
555 Cell::new(if ok { "x" } else { "INVALID" }).set_alignment(CellAlignment::Center),
556 Cell::new(if mig.reversible { "x" } else { "" }).set_alignment(CellAlignment::Center),
557 ]));
558 }
559
560 println!("{}", table);
561
562 if !all_valid {
563 process::exit(1);
564 }
565}
566
567fn print_summary(summary: &MigrationSummary) {
568 let mut table = Table::new();
569
570 table
571 .set_content_arrangement(ContentArrangement::Dynamic)
572 .set_header(Vec::from([
573 Cell::new("Old Version").set_alignment(CellAlignment::Center),
574 Cell::new("New Version").set_alignment(CellAlignment::Center),
575 Cell::new("Applied Migrations").set_alignment(CellAlignment::Center),
576 Cell::new("Reverted Migrations").set_alignment(CellAlignment::Center),
577 ]));
578
579 let mut s = Vec::<Cell>::new();
580
581 s.push(match summary.old_version {
582 Some(v) => Cell::new(v.to_string()).set_alignment(CellAlignment::Center),
583 None => "".into(),
584 });
585
586 s.push(match summary.new_version {
587 Some(v) => Cell::new(v.to_string()).set_alignment(CellAlignment::Center),
588 None => "".into(),
589 });
590
591 s.push(match (summary.old_version, summary.new_version) {
592 (Some(old), Some(new)) => {
593 if new >= old {
594 Cell::new((new - old).to_string()).set_alignment(CellAlignment::Center)
595 } else {
596 Cell::new("0").set_alignment(CellAlignment::Center)
597 }
598 }
599 (None, Some(new)) => Cell::new(new.to_string()).set_alignment(CellAlignment::Center),
600 (_, None) => Cell::new("0").set_alignment(CellAlignment::Center),
601 });
602
603 s.push(match (summary.old_version, summary.new_version) {
604 (Some(old), Some(new)) => {
605 if new <= old {
606 Cell::new((old - new).to_string()).set_alignment(CellAlignment::Center)
607 } else {
608 Cell::new("0").set_alignment(CellAlignment::Center)
609 }
610 }
611 (Some(old), None) => Cell::new(old.to_string()).set_alignment(CellAlignment::Center),
612 (None, _) => Cell::new("0").set_alignment(CellAlignment::Center),
613 });
614
615 table.add_row(s);
616
617 eprintln!("{table}");
618}
619
620async fn setup_migrator<Db>(migrate: &Migrate, migrations: Vec<Migration<Db>>) -> Migrator<Db>
621where
622 Db: Database,
623 Db::Connection: db::Migrations,
624 for<'a> &'a mut Db::Connection: Executor<'a>,
625{
626 let db_url = match &migrate.database_url {
627 Some(s) => s.clone(),
628 None => {
629 if let Ok(url) = std::env::var("DATABASE_URL") {
630 url
631 } else {
632 tracing::error!(
633 "`DATABASE_URL` environment variable or `--database-url` argument is required"
634 );
635 process::exit(1);
636 }
637 }
638 };
639
640 let mut options =
641 match db_url.parse::<<<Db as Database>::Connection as sqlx::Connection>::Options>() {
642 Ok(opts) => opts,
643 Err(err) => {
644 tracing::error!(error = %err, "invalid database URL");
645 process::exit(1);
646 }
647 };
648
649 if migrate.log_statements {
650 options = options
651 .log_statements("INFO".parse().unwrap())
652 .log_slow_statements("WARN".parse().unwrap(), Duration::from_secs(1));
653 } else {
654 options = options.disable_statement_logging();
655 }
656
657 match Migrator::connect_with(&options).await {
658 Ok(mut mig) => {
659 mig.set_options(MigratorOptions {
660 verify_checksums: !migrate.no_verify_checksums,
661 verify_names: !migrate.no_verify_names,
662 });
663
664 if !migrate.migrations_table.is_empty() {
665 mig.set_migrations_table(&migrate.migrations_table);
666 }
667
668 mig.add_migrations(migrations);
669
670 mig
671 }
672 Err(err) => {
673 tracing::error!(error = %err, "failed to create database connection");
674 process::exit(1);
675 }
676 }
677}
678
679fn setup_logging(migrate: &Migrate) {
680 let format = tracing_subscriber::fmt::format().with_ansi(colors(migrate));
681
682 let verbose = migrate.verbose;
683
684 let span_events = if verbose {
685 FmtSpan::NEW | FmtSpan::CLOSE
686 } else {
687 FmtSpan::CLOSE
688 };
689
690 let registry = tracing_subscriber::registry();
691
692 let env_filter = match EnvFilter::try_from_default_env() {
693 Ok(f) => f,
694 Err(_) => EnvFilter::default()
695 .add_directive(tracing::Level::INFO.into())
696 .add_directive("sqlx::postgres::notice=error".parse().unwrap()),
697 };
698
699 if verbose {
700 registry
701 .with(env_filter)
702 .with(
703 tracing_subscriber::fmt::layer()
704 .with_writer(io::stderr)
705 .with_span_events(span_events)
706 .event_format(format.pretty()),
707 )
708 .init();
709 } else {
710 registry
711 .with(env_filter)
712 .with(
713 tracing_subscriber::fmt::layer()
714 .with_writer(io::stderr)
715 .with_span_events(span_events)
716 .event_format(format),
717 )
718 .init();
719 }
720}
721
722fn colors(matches: &Migrate) -> bool {
723 if matches.no_colors {
724 return false;
725 }
726
727 atty::is(atty::Stream::Stdout)
728}