sea_orm_migration/
cli.rs

1use std::future::Future;
2
3use clap::Parser;
4use dotenvy::dotenv;
5use std::{error::Error, fmt::Display, process::exit};
6use tracing_subscriber::{prelude::*, EnvFilter};
7
8use sea_orm::{ConnectOptions, Database, DbConn, DbErr};
9use sea_orm_cli::{run_migrate_generate, run_migrate_init, MigrateSubcommands};
10
11use super::MigratorTrait;
12
/// Directory handed to `run_migrate_init` and `run_migrate_generate` by the
/// `init` and `generate` subcommands: the current working directory.
const MIGRATION_DIR: &str = "./";
14
15pub async fn run_cli<M>(migrator: M)
16where
17    M: MigratorTrait,
18{
19    run_cli_with_connection(migrator, Database::connect).await;
20}
21
22/// Same as [`run_cli`] where you provide the function to create the [`DbConn`].
23///
24/// This allows configuring the database connection as you see fit.
25/// E.g. you can change settings in [`ConnectOptions`] or you can load sqlite
26/// extensions.
27pub async fn run_cli_with_connection<M, F, Fut>(migrator: M, make_connection: F)
28where
29    M: MigratorTrait,
30    F: FnOnce(ConnectOptions) -> Fut,
31    Fut: Future<Output = Result<DbConn, DbErr>>,
32{
33    dotenv().ok();
34    let cli = Cli::parse();
35
36    let url = cli
37        .database_url
38        .expect("Environment variable 'DATABASE_URL' not set");
39    let schema = cli.database_schema.unwrap_or_else(|| "public".to_owned());
40
41    let connect_options = ConnectOptions::new(url)
42        .set_schema_search_path(schema)
43        .to_owned();
44
45    let db = make_connection(connect_options)
46        .await
47        .expect("Fail to acquire database connection");
48
49    run_migrate(migrator, &db, cli.command, cli.verbose)
50        .await
51        .unwrap_or_else(handle_error);
52}
53
54pub async fn run_migrate<M>(
55    _: M,
56    db: &DbConn,
57    command: Option<MigrateSubcommands>,
58    verbose: bool,
59) -> Result<(), Box<dyn Error>>
60where
61    M: MigratorTrait,
62{
63    let filter = match verbose {
64        true => "debug",
65        false => "sea_orm_migration=info",
66    };
67
68    let filter_layer = EnvFilter::try_new(filter).unwrap();
69
70    if verbose {
71        let fmt_layer = tracing_subscriber::fmt::layer();
72        tracing_subscriber::registry()
73            .with(filter_layer)
74            .with(fmt_layer)
75            .init()
76    } else {
77        let fmt_layer = tracing_subscriber::fmt::layer()
78            .with_target(false)
79            .with_level(false)
80            .without_time();
81        tracing_subscriber::registry()
82            .with(filter_layer)
83            .with(fmt_layer)
84            .init()
85    };
86
87    match command {
88        Some(MigrateSubcommands::Fresh) => M::fresh(db).await?,
89        Some(MigrateSubcommands::Refresh) => M::refresh(db).await?,
90        Some(MigrateSubcommands::Reset) => M::reset(db).await?,
91        Some(MigrateSubcommands::Status) => M::status(db).await?,
92        Some(MigrateSubcommands::Up { num }) => M::up(db, num).await?,
93        Some(MigrateSubcommands::Down { num }) => M::down(db, Some(num)).await?,
94        Some(MigrateSubcommands::Init) => run_migrate_init(MIGRATION_DIR)?,
95        Some(MigrateSubcommands::Generate {
96            migration_name,
97            universal_time: _,
98            local_time,
99        }) => run_migrate_generate(MIGRATION_DIR, &migration_name, !local_time)?,
100        _ => M::up(db, None).await?,
101    };
102
103    Ok(())
104}
105
106#[derive(Parser)]
107#[command(version)]
108pub struct Cli {
109    #[arg(short = 'v', long, global = true, help = "Show debug messages")]
110    verbose: bool,
111
112    #[arg(
113        global = true,
114        short = 's',
115        long,
116        env = "DATABASE_SCHEMA",
117        long_help = "Database schema\n \
118                    - For MySQL and SQLite, this argument is ignored.\n \
119                    - For PostgreSQL, this argument is optional with default value 'public'.\n"
120    )]
121    database_schema: Option<String>,
122
123    #[arg(
124        global = true,
125        short = 'u',
126        long,
127        env = "DATABASE_URL",
128        help = "Database URL"
129    )]
130    database_url: Option<String>,
131
132    #[command(subcommand)]
133    command: Option<MigrateSubcommands>,
134}
135
/// Prints `error` to stderr using its [`Display`] form, then terminates the
/// process with exit status 1.
///
/// The declared return type is `()` so the function can be used directly as
/// a fallback for a `Result<(), E>`, but it never actually returns.
fn handle_error<E: Display>(error: E) {
    eprintln!("{}", error);
    exit(1)
}