sqlx_migrate/
cli.rs

1#![allow(
2    clippy::struct_excessive_bools,
3    clippy::too_many_lines,
4    unused_imports,
5    dead_code,
6    unused_variables
7)]
8use crate::{db, prelude::*, DatabaseType, DEFAULT_MIGRATIONS_TABLE};
9use clap::Parser;
10use comfy_table::{Cell, CellAlignment, ContentArrangement, Table};
11use filetime::FileTime;
12use regex::Regex;
13use sqlx::{ConnectOptions, Database, Executor};
14use std::{fs, io, path::Path, process, str::FromStr, time::Duration};
15use time::{format_description, OffsetDateTime};
16use tracing_subscriber::{
17    fmt::format::FmtSpan, prelude::__tracing_subscriber_SubscriberExt, util::SubscriberInitExt,
18    EnvFilter,
19};
20
21/// Command-line arguments.
22#[derive(Debug, clap::Parser)]
23pub struct Migrate {
24    /// Disable colors in messages.
25    #[clap(long, global(true))]
26    pub no_colors: bool,
27    /// Enable the logging of tracing spans.
28    #[clap(long, global(true))]
29    pub verbose: bool,
30    /// Force the operation, required for some actions.
31    #[clap(long = "force", global(true))]
32    pub force: bool,
33    /// Skip verifying migration checksums.
34    #[clap(long, alias = "no-verify-checksum", global(true))]
35    pub no_verify_checksums: bool,
36    /// Skip verifying migration names.
37    #[clap(long, alias = "no-verify-name", global(true))]
38    pub no_verify_names: bool,
39    /// Skip loading .env files.
40    #[clap(long, global(true))]
41    pub no_env_file: bool,
42    /// Log all SQL statements.
43    #[clap(long, global(true))]
44    pub log_statements: bool,
45    /// Database URL, if not given the `DATABASE_URL` environment variable will be used.
46    #[clap(long, visible_alias = "db-url", global(true))]
47    pub database_url: Option<String>,
48    /// The name of the migrations table.
49    #[clap(long, default_value = DEFAULT_MIGRATIONS_TABLE, global(true))]
50    pub migrations_table: String,
51    #[clap(subcommand)]
52    pub operation: Operation,
53}
54
55/// A command-line operation.
56#[derive(Debug, clap::Subcommand)]
57pub enum Operation {
58    /// Apply all migrations up to and including the given migration.
59    ///
60    /// If no migration is given, all migrations are applied.
61    #[clap(visible_aliases = &["up", "mig"])]
62    Migrate {
63        /// Apply all migrations up to and including the migration
64        /// with the given name.
65        #[clap(long, conflicts_with = "version")]
66        name: Option<String>,
67
68        /// Apply all migrations up to and including the migration
69        /// with the given version.
70        #[clap(long, conflicts_with = "name")]
71        version: Option<u64>,
72    },
73    /// Revert the given migration and all subsequent ones.
74    ///
75    /// If no migration is set, all applied migrations are reverted.
76    #[clap(visible_aliases = &["down", "rev"])]
77    Revert {
78        /// Revert all migrations after and including the migration
79        /// with the given name.
80        #[clap(long, conflicts_with = "version")]
81        name: Option<String>,
82
83        /// Revert all migrations after and including the migration
84        /// the given version.
85        #[clap(long, conflicts_with = "name")]
86        version: Option<u64>,
87    },
88    /// Forcibly set a given migration.
89    ///
90    /// This does not apply nor revert any migrations, and
91    /// only overrides migration status.
92    #[clap(visible_aliases = &["override"])]
93    Set {
94        /// Forcibly set the migration with the given name.
95        #[clap(long, conflicts_with = "version", required_unless_present("version"))]
96        name: Option<String>,
97        /// Forcibly set the migration with the given version.
98        #[clap(long, conflicts_with = "name", required_unless_present("name"))]
99        version: Option<u64>,
100    },
101    /// Verify migrations and print errors.
102    #[clap(visible_aliases = &["verify", "validate"])]
103    Check {},
104    /// List all migrations.
105    #[clap(visible_aliases = &["list", "ls", "get"])]
106    Status {},
107    /// Add a new migration.
108    ///
109    /// The migrations default to Rust files.
110    #[cfg(debug_assertions)]
111    #[clap(visible_aliases = &["new"])]
112    Add {
113        /// Use SQL for the migrations.
114        #[clap(long)]
115        sql: bool,
116        /// Create a "revert" or "down" migration.
117        #[clap(long, short = 'r', visible_aliases = &["revert", "revertible"])]
118        reversible: bool,
119        /// The SQLx type of the database in Rust migrations.
120        ///
121        /// By default, all migrations will be using `Any`.
122        #[clap(
123            long = "database",
124            visible_aliases = &["db"],
125            aliases = &["type"],
126            default_value = "postgres",
127            value_enum
128        )]
129        ty: DatabaseType,
130        /// The name of the migration.
131        ///
132        /// It must be across all migrations.
133        name: String,
134    },
135}
136
137/// Run a CLI application that provides operations with the
138/// given migrations.
139///
140/// When compiled with `debug_assertions`, it additionally allows modifying migrations
141/// at the given `migrations_path`.
142///
143/// Although not required, `migrations` are expected to be originated from `migrations_path`.
144///
145/// # Panics
146///
147/// This functon assumes that it has control over the entire application.
148///
149/// It will happily alter global state (tracing), panic, or terminate the process.
150pub fn run<Db>(
151    migrations_path: impl AsRef<Path>,
152    migrations: impl IntoIterator<Item = Migration<Db>>,
153) where
154    Db: Database,
155    Db::Connection: db::Migrations,
156    for<'a> &'a mut Db::Connection: Executor<'a>,
157{
158    run_parsed(Migrate::parse(), migrations_path, migrations);
159}
160
161/// Same as [`run`], but allows for parsing and inspecting [`Migrate`] beforehand.
162#[allow(clippy::missing_panics_doc)]
163pub fn run_parsed<Db>(
164    migrate: Migrate,
165    migrations_path: impl AsRef<Path>,
166    migrations: impl IntoIterator<Item = Migration<Db>>,
167) where
168    Db: Database,
169    Db::Connection: db::Migrations,
170    for<'a> &'a mut Db::Connection: Executor<'a>,
171{
172    setup_logging(&migrate);
173
174    if !migrate.no_env_file {
175        if let Ok(cwd) = std::env::current_dir() {
176            let env_path = cwd.join(".env");
177            if env_path.is_file() {
178                tracing::info!(path = ?env_path, ".env file found");
179                if let Err(err) = dotenvy::from_path(&env_path) {
180                    tracing::warn!(path = ?env_path, error = %err, "failed to load .env file");
181                }
182            }
183        }
184    }
185
186    let migrations = migrations.into_iter().collect::<Vec<_>>();
187
188    tokio::runtime::Builder::new_current_thread()
189        .enable_all()
190        .build()
191        .unwrap()
192        .block_on(execute(migrate, migrations_path.as_ref(), migrations));
193}
194
195async fn execute<Db>(migrate: Migrate, migrations_path: &Path, migrations: Vec<Migration<Db>>)
196where
197    Db: Database,
198    Db::Connection: db::Migrations,
199    for<'a> &'a mut Db::Connection: Executor<'a>,
200{
201    match &migrate.operation {
202        Operation::Migrate { name, version } => {
203            let migrator = setup_migrator(&migrate, migrations).await;
204            do_migrate(&migrate, migrator, name.as_deref(), *version).await;
205        }
206        Operation::Revert { name, version } => {
207            let migrator = setup_migrator(&migrate, migrations).await;
208            revert(&migrate, migrator, name.as_deref(), *version).await;
209        }
210        Operation::Set { name, version } => {
211            let migrator = setup_migrator(&migrate, migrations).await;
212            force(&migrate, migrator, name.as_deref(), *version).await;
213        }
214        Operation::Check {} => {
215            let migrator = setup_migrator(&migrate, migrations).await;
216            check(&migrate, migrator).await;
217        }
218        Operation::Status {} => {
219            let migrator = setup_migrator(&migrate, migrations).await;
220            log_status(&migrate, migrator).await;
221        }
222        #[cfg(debug_assertions)]
223        Operation::Add {
224            sql,
225            reversible,
226            name,
227            ty,
228        } => add(&migrate, migrations_path, *sql, *reversible, name, *ty),
229    }
230}
231
232async fn check<Db>(_migrate: &Migrate, migrator: Migrator<Db>)
233where
234    Db: Database,
235    Db::Connection: db::Migrations,
236    for<'a> &'a mut Db::Connection: Executor<'a>,
237{
238    match migrator.verify().await {
239        Ok(_) => {
240            tracing::info!("No issues found");
241        }
242        Err(err) => {
243            tracing::error!(error = %err, "error verifying migrations");
244            process::exit(1);
245        }
246    }
247}
248
249#[cfg(debug_assertions)]
250fn add(
251    _migrate: &Migrate,
252    migrations_path: &Path,
253    sql: bool,
254    reversible: bool,
255    name: &str,
256    ty: DatabaseType,
257) {
258    let now = OffsetDateTime::now_utc();
259
260    let now_formatted = now
261        .format(&format_description::parse("[year][month][day][hour][minute][second]").unwrap())
262        .unwrap();
263
264    if !migrations_path.is_dir() {
265        tracing::error!("migrations path must be a directory");
266        process::exit(1);
267    }
268
269    let re = Regex::new("[A-Za-z_][A-Za-z_0-9]*").unwrap();
270
271    if !re.is_match(name) {
272        tracing::error!(name, "invalid migration name");
273        process::exit(1);
274    }
275
276    if sql {
277        let up_filename = format!("{}_{}.migrate.sql", &now_formatted, name);
278
279        if let Err(error) = fs::write(
280            migrations_path.join(&up_filename),
281            format!(
282                r#"-- Migration SQL for {name}
283"#,
284            ),
285        ) {
286            tracing::error!(error = %error, path = ?migrations_path.join(&up_filename), "failed to write file");
287            process::exit(1);
288        }
289
290        if reversible {
291            let down_filename = format!("{}_{}.revert.sql", &now_formatted, name);
292            if let Err(error) = fs::write(
293                migrations_path.join(&down_filename),
294                format!(
295                    r#"-- Revert SQL for {name}
296"#,
297                ),
298            ) {
299                tracing::error!(error = %error, path = ?migrations_path.join(&down_filename), "failed to write file");
300                process::exit(1);
301            }
302        }
303
304        tracing::info!(name, "added migration");
305    } else {
306        let up_filename = format!("{}_{}.migrate.rs", &now_formatted, name);
307
308        let sqlx_type = ty.sqlx_type();
309
310        if let Err(error) = fs::write(
311            migrations_path.join(&up_filename),
312            format!(
313                r#"use sqlx::{sqlx_type};
314use sqlx_migrate::prelude::*;
315
316/// Executes migration `{name}` in the given migration context.
317//
318// Do not modify the function name.
319// Do not modify the signature with the exception of the SQLx database type.
320pub async fn {name}(ctx: &mut MigrationContext<{sqlx_type}>) -> Result<(), MigrationError> {{
321    // write your migration operations here
322    todo!()
323}}
324"#,
325            ),
326        ) {
327            tracing::error!(error = %error, path = ?migrations_path.join(&up_filename), "failed to write file");
328            process::exit(1);
329        }
330
331        if reversible {
332            let down_filename = format!("{}_{}.revert.rs", &now_formatted, name);
333
334            if let Err(error) = fs::write(
335                migrations_path.join(&down_filename),
336                format!(
337                    r#"use sqlx::{sqlx_type};
338use sqlx_migrate::prelude::*;
339
340/// Reverts migration `{name}` in the given migration context.
341//
342// Do not modify the function name.
343// Do not modify the signature with the exception of the SQLx database type.
344pub async fn revert_{name}(ctx: &mut MigrationContext<{sqlx_type}>) -> Result<(), MigrationError> {{
345    // write your revert operations here
346    todo!()
347}}
348"#,
349                ),
350            ) {
351                tracing::error!(error = %error, path = ?migrations_path.join(&down_filename), "failed to write file");
352                process::exit(1);
353            }
354        }
355    }
356
357    if let Err(err) = filetime::set_file_mtime(migrations_path, FileTime::now()) {
358        tracing::debug!(error = %err, "error updating the migrations directory");
359    }
360}
361
362async fn do_migrate<Db>(
363    _migrate: &Migrate,
364    migrator: Migrator<Db>,
365    name: Option<&str>,
366    version: Option<u64>,
367) where
368    Db: Database,
369    Db::Connection: db::Migrations,
370    for<'a> &'a mut Db::Connection: Executor<'a>,
371{
372    let version = match version {
373        Some(v) => Some(v),
374        None => match name {
375            Some(name) => {
376                if let Some((idx, _)) = migrator
377                    .local_migrations()
378                    .iter()
379                    .enumerate()
380                    .find(|mig| mig.1.name() == name)
381                {
382                    Some(idx as u64 + 1)
383                } else {
384                    tracing::error!(name = name, "migration not found");
385                    process::exit(1);
386                }
387            }
388            None => None,
389        },
390    };
391
392    match version {
393        Some(version) => match migrator.migrate(version).await {
394            Ok(s) => print_summary(&s),
395            Err(error) => {
396                tracing::error!(error = %error, "error applying migrations");
397                process::exit(1);
398            }
399        },
400        None => match migrator.migrate_all().await {
401            Ok(s) => print_summary(&s),
402            Err(error) => {
403                tracing::error!(error = %error, "error applying migrations");
404                process::exit(1);
405            }
406        },
407    }
408}
409
410async fn revert<Db>(
411    migrate: &Migrate,
412    migrator: Migrator<Db>,
413    name: Option<&str>,
414    version: Option<u64>,
415) where
416    Db: Database,
417    Db::Connection: db::Migrations,
418    for<'a> &'a mut Db::Connection: Executor<'a>,
419{
420    if !migrate.force {
421        tracing::error!("the `--force` flag is required for this operation");
422        process::exit(1);
423    }
424
425    let version = match version {
426        Some(v) => Some(v),
427        None => match name {
428            Some(name) => {
429                if let Some((idx, _)) = migrator
430                    .local_migrations()
431                    .iter()
432                    .enumerate()
433                    .find(|mig| mig.1.name() == name)
434                {
435                    Some(idx as u64 + 1)
436                } else {
437                    tracing::error!(name = name, "migration not found");
438                    process::exit(1);
439                }
440            }
441            None => None,
442        },
443    };
444
445    match version {
446        Some(version) => match migrator.revert(version).await {
447            Ok(s) => print_summary(&s),
448            Err(error) => {
449                tracing::error!(error = %error, "error reverting migrations");
450                process::exit(1);
451            }
452        },
453        None => match migrator.revert_all().await {
454            Ok(s) => print_summary(&s),
455            Err(error) => {
456                tracing::error!(error = %error, "error reverting migrations");
457                process::exit(1);
458            }
459        },
460    }
461}
462
463async fn force<Db>(
464    migrate: &Migrate,
465    migrator: Migrator<Db>,
466    name: Option<&str>,
467    version: Option<u64>,
468) where
469    Db: Database,
470    Db::Connection: db::Migrations,
471    for<'a> &'a mut Db::Connection: Executor<'a>,
472{
473    if !migrate.force {
474        tracing::error!("the `--do-as-i-say` or `--force` flag is required for this operation");
475        process::exit(1);
476    }
477
478    let version = match version {
479        Some(v) => v,
480        None => {
481            if let Some((idx, _)) = migrator
482                .local_migrations()
483                .iter()
484                .enumerate()
485                .find(|mig| mig.1.name() == name.unwrap())
486            {
487                idx as u64 + 1
488            } else {
489                tracing::error!(name = name.unwrap(), "migration not found");
490                process::exit(1);
491            }
492        }
493    };
494
495    match migrator.force_version(version).await {
496        Ok(s) => print_summary(&s),
497        Err(error) => {
498            tracing::error!(error = %error, "error updating migrations");
499            process::exit(1);
500        }
501    }
502}
503
504async fn log_status<Db>(_migrate: &Migrate, migrator: Migrator<Db>)
505where
506    Db: Database,
507    Db::Connection: db::Migrations,
508    for<'a> &'a mut Db::Connection: Executor<'a>,
509{
510    fn mig_ok(status: &MigrationStatus) -> bool {
511        if status.missing_local {
512            return false;
513        }
514
515        match &status.applied {
516            Some(applied) => {
517                status.checksum_ok
518                    && status.name == applied.name
519                    && status.version == applied.version
520            }
521            None => true,
522        }
523    }
524
525    let status = match migrator.status().await {
526        Ok(s) => s,
527        Err(error) => {
528            tracing::error!(error = %error, "error retrieving migration status");
529            process::exit(1);
530        }
531    };
532
533    let all_valid = status.iter().all(mig_ok);
534
535    let mut table = Table::new();
536
537    table
538        .set_content_arrangement(ContentArrangement::Dynamic)
539        .set_header(Vec::from([
540            Cell::new("Version").set_alignment(CellAlignment::Center),
541            Cell::new("Name").set_alignment(CellAlignment::Center),
542            Cell::new("Applied").set_alignment(CellAlignment::Center),
543            Cell::new("Valid").set_alignment(CellAlignment::Center),
544            Cell::new("Revertible").set_alignment(CellAlignment::Center),
545        ]));
546
547    for mig in status {
548        let ok = mig_ok(&mig);
549
550        table.add_row(Vec::from([
551            Cell::new(mig.version.to_string().as_str()).set_alignment(CellAlignment::Center),
552            Cell::new(&mig.name).set_alignment(CellAlignment::Center),
553            Cell::new(if mig.applied.is_some() { "x" } else { "" })
554                .set_alignment(CellAlignment::Center),
555            Cell::new(if ok { "x" } else { "INVALID" }).set_alignment(CellAlignment::Center),
556            Cell::new(if mig.reversible { "x" } else { "" }).set_alignment(CellAlignment::Center),
557        ]));
558    }
559
560    println!("{}", table);
561
562    if !all_valid {
563        process::exit(1);
564    }
565}
566
567fn print_summary(summary: &MigrationSummary) {
568    let mut table = Table::new();
569
570    table
571        .set_content_arrangement(ContentArrangement::Dynamic)
572        .set_header(Vec::from([
573            Cell::new("Old Version").set_alignment(CellAlignment::Center),
574            Cell::new("New Version").set_alignment(CellAlignment::Center),
575            Cell::new("Applied Migrations").set_alignment(CellAlignment::Center),
576            Cell::new("Reverted Migrations").set_alignment(CellAlignment::Center),
577        ]));
578
579    let mut s = Vec::<Cell>::new();
580
581    s.push(match summary.old_version {
582        Some(v) => Cell::new(v.to_string()).set_alignment(CellAlignment::Center),
583        None => "".into(),
584    });
585
586    s.push(match summary.new_version {
587        Some(v) => Cell::new(v.to_string()).set_alignment(CellAlignment::Center),
588        None => "".into(),
589    });
590
591    s.push(match (summary.old_version, summary.new_version) {
592        (Some(old), Some(new)) => {
593            if new >= old {
594                Cell::new((new - old).to_string()).set_alignment(CellAlignment::Center)
595            } else {
596                Cell::new("0").set_alignment(CellAlignment::Center)
597            }
598        }
599        (None, Some(new)) => Cell::new(new.to_string()).set_alignment(CellAlignment::Center),
600        (_, None) => Cell::new("0").set_alignment(CellAlignment::Center),
601    });
602
603    s.push(match (summary.old_version, summary.new_version) {
604        (Some(old), Some(new)) => {
605            if new <= old {
606                Cell::new((old - new).to_string()).set_alignment(CellAlignment::Center)
607            } else {
608                Cell::new("0").set_alignment(CellAlignment::Center)
609            }
610        }
611        (Some(old), None) => Cell::new(old.to_string()).set_alignment(CellAlignment::Center),
612        (None, _) => Cell::new("0").set_alignment(CellAlignment::Center),
613    });
614
615    table.add_row(s);
616
617    eprintln!("{table}");
618}
619
620async fn setup_migrator<Db>(migrate: &Migrate, migrations: Vec<Migration<Db>>) -> Migrator<Db>
621where
622    Db: Database,
623    Db::Connection: db::Migrations,
624    for<'a> &'a mut Db::Connection: Executor<'a>,
625{
626    let db_url = match &migrate.database_url {
627        Some(s) => s.clone(),
628        None => {
629            if let Ok(url) = std::env::var("DATABASE_URL") {
630                url
631            } else {
632                tracing::error!(
633                    "`DATABASE_URL` environment variable or `--database-url` argument is required"
634                );
635                process::exit(1);
636            }
637        }
638    };
639
640    let mut options =
641        match db_url.parse::<<<Db as Database>::Connection as sqlx::Connection>::Options>() {
642            Ok(opts) => opts,
643            Err(err) => {
644                tracing::error!(error = %err, "invalid database URL");
645                process::exit(1);
646            }
647        };
648
649    if migrate.log_statements {
650        options = options
651            .log_statements("INFO".parse().unwrap())
652            .log_slow_statements("WARN".parse().unwrap(), Duration::from_secs(1));
653    } else {
654        options = options.disable_statement_logging();
655    }
656
657    match Migrator::connect_with(&options).await {
658        Ok(mut mig) => {
659            mig.set_options(MigratorOptions {
660                verify_checksums: !migrate.no_verify_checksums,
661                verify_names: !migrate.no_verify_names,
662            });
663
664            if !migrate.migrations_table.is_empty() {
665                mig.set_migrations_table(&migrate.migrations_table);
666            }
667
668            mig.add_migrations(migrations);
669
670            mig
671        }
672        Err(err) => {
673            tracing::error!(error = %err, "failed to create database connection");
674            process::exit(1);
675        }
676    }
677}
678
679fn setup_logging(migrate: &Migrate) {
680    let format = tracing_subscriber::fmt::format().with_ansi(colors(migrate));
681
682    let verbose = migrate.verbose;
683
684    let span_events = if verbose {
685        FmtSpan::NEW | FmtSpan::CLOSE
686    } else {
687        FmtSpan::CLOSE
688    };
689
690    let registry = tracing_subscriber::registry();
691
692    let env_filter = match EnvFilter::try_from_default_env() {
693        Ok(f) => f,
694        Err(_) => EnvFilter::default()
695            .add_directive(tracing::Level::INFO.into())
696            .add_directive("sqlx::postgres::notice=error".parse().unwrap()),
697    };
698
699    if verbose {
700        registry
701            .with(env_filter)
702            .with(
703                tracing_subscriber::fmt::layer()
704                    .with_writer(io::stderr)
705                    .with_span_events(span_events)
706                    .event_format(format.pretty()),
707            )
708            .init();
709    } else {
710        registry
711            .with(env_filter)
712            .with(
713                tracing_subscriber::fmt::layer()
714                    .with_writer(io::stderr)
715                    .with_span_events(span_events)
716                    .event_format(format),
717            )
718            .init();
719    }
720}
721
722fn colors(matches: &Migrate) -> bool {
723    if matches.no_colors {
724        return false;
725    }
726
727    atty::is(atty::Stream::Stdout)
728}