sqlx_cli/
lib.rs

1use std::io;
2use std::time::Duration;
3
4use anyhow::Result;
5use futures::{Future, TryFutureExt};
6
7use sqlx::{AnyConnection, Connection};
8
9use crate::opt::{Command, ConnectOpts, DatabaseCommand, MigrateCommand};
10
11mod database;
12mod metadata;
13// mod migration;
14// mod migrator;
15#[cfg(feature = "completions")]
16mod completions;
17mod migrate;
18mod opt;
19mod prepare;
20
21pub use crate::opt::Opt;
22
23pub async fn run(opt: Opt) -> Result<()> {
24    match opt.command {
25        Command::Migrate(migrate) => match migrate.command {
26            MigrateCommand::Add {
27                source,
28                description,
29                reversible,
30                sequential,
31                timestamp,
32            } => migrate::add(&source, &description, reversible, sequential, timestamp).await?,
33            MigrateCommand::Run {
34                source,
35                dry_run,
36                ignore_missing,
37                connect_opts,
38                target_version,
39            } => {
40                migrate::run(
41                    &source,
42                    &connect_opts,
43                    dry_run,
44                    *ignore_missing,
45                    target_version,
46                )
47                .await?
48            }
49            MigrateCommand::Revert {
50                source,
51                dry_run,
52                ignore_missing,
53                connect_opts,
54                target_version,
55            } => {
56                migrate::revert(
57                    &source,
58                    &connect_opts,
59                    dry_run,
60                    *ignore_missing,
61                    target_version,
62                )
63                .await?
64            }
65            MigrateCommand::Info {
66                source,
67                connect_opts,
68            } => migrate::info(&source, &connect_opts).await?,
69            MigrateCommand::BuildScript { source, force } => migrate::build_script(&source, force)?,
70        },
71
72        Command::Database(database) => match database.command {
73            DatabaseCommand::Create { connect_opts } => database::create(&connect_opts).await?,
74            DatabaseCommand::Drop {
75                confirmation,
76                connect_opts,
77                force,
78            } => database::drop(&connect_opts, !confirmation.yes, force).await?,
79            DatabaseCommand::Reset {
80                confirmation,
81                source,
82                connect_opts,
83                force,
84            } => database::reset(&source, &connect_opts, !confirmation.yes, force).await?,
85            DatabaseCommand::Setup {
86                source,
87                connect_opts,
88            } => database::setup(&source, &connect_opts).await?,
89        },
90
91        Command::Prepare {
92            check,
93            all,
94            workspace,
95            connect_opts,
96            args,
97        } => prepare::run(check, all, workspace, connect_opts, args).await?,
98
99        #[cfg(feature = "completions")]
100        Command::Completions { shell } => completions::run(shell),
101    };
102
103    Ok(())
104}
105
106/// Attempt to connect to the database server, retrying up to `ops.connect_timeout`.
107async fn connect(opts: &ConnectOpts) -> anyhow::Result<AnyConnection> {
108    retry_connect_errors(opts, AnyConnection::connect).await
109}
110
111/// Attempt an operation that may return errors like `ConnectionRefused`,
112/// retrying up until `ops.connect_timeout`.
113///
114/// The closure is passed `&ops.database_url` for easy composition.
115async fn retry_connect_errors<'a, F, Fut, T>(
116    opts: &'a ConnectOpts,
117    mut connect: F,
118) -> anyhow::Result<T>
119where
120    F: FnMut(&'a str) -> Fut,
121    Fut: Future<Output = sqlx::Result<T>> + 'a,
122{
123    sqlx::any::install_default_drivers();
124
125    let db_url = opts.required_db_url()?;
126
127    backoff::future::retry(
128        backoff::ExponentialBackoffBuilder::new()
129            .with_max_elapsed_time(Some(Duration::from_secs(opts.connect_timeout)))
130            .build(),
131        || {
132            connect(db_url).map_err(|e| -> backoff::Error<anyhow::Error> {
133                if let sqlx::Error::Io(ref ioe) = e {
134                    match ioe.kind() {
135                        io::ErrorKind::ConnectionRefused
136                        | io::ErrorKind::ConnectionReset
137                        | io::ErrorKind::ConnectionAborted => {
138                            return backoff::Error::transient(e.into());
139                        }
140                        _ => (),
141                    }
142                }
143
144                backoff::Error::permanent(e.into())
145            })
146        },
147    )
148    .await
149}