sqlx_scylladb_cli/
lib.rs

1use std::time::Duration;
2use std::{io, sync::Once};
3
4use anyhow::Result;
5use futures_util::{Future, TryFutureExt};
6
7use sqlx::any::install_drivers;
8use sqlx::{AnyConnection, Connection};
9use tokio::{select, signal};
10
11use crate::opt::{Command, ConnectOpts, DatabaseCommand, MigrateCommand};
12
13mod database;
14mod migrate;
15mod opt;
16
17pub use crate::opt::Opt;
18
19/// Check arguments for `--no-dotenv` _before_ Clap parsing, and apply `.env` if not set.
20pub fn maybe_apply_dotenv() {
21    if std::env::args().any(|arg| arg == "--no-dotenv") {
22        return;
23    }
24
25    dotenvy::dotenv().ok();
26}
27
28pub async fn run(opt: Opt) -> Result<()> {
29    // This `select!` is here so that when the process receives a `SIGINT` (CTRL + C),
30    // the futures currently running on this task get dropped before the program exits.
31    // This is currently necessary for the consumers of the `dialoguer` crate to restore
32    // the user's terminal if the process is interrupted while a dialog is being displayed.
33
34    let ctrlc_fut = signal::ctrl_c();
35    let do_run_fut = do_run(opt);
36
37    select! {
38        biased;
39        _ = ctrlc_fut => {
40            Ok(())
41        },
42        do_run_outcome = do_run_fut => {
43            do_run_outcome
44        }
45    }
46}
47
48async fn do_run(opt: Opt) -> Result<()> {
49    match opt.command {
50        Command::Migrate(migrate) => match migrate.command {
51            MigrateCommand::Add {
52                source,
53                description,
54                reversible,
55                sequential,
56                timestamp,
57            } => migrate::add(&source, &description, reversible, sequential, timestamp).await?,
58            MigrateCommand::Run {
59                source,
60                dry_run,
61                ignore_missing,
62                connect_opts,
63                target_version,
64            } => {
65                migrate::run(
66                    &source,
67                    &connect_opts,
68                    dry_run,
69                    *ignore_missing,
70                    target_version,
71                )
72                .await?
73            }
74            MigrateCommand::Revert {
75                source,
76                dry_run,
77                ignore_missing,
78                connect_opts,
79                target_version,
80            } => {
81                migrate::revert(
82                    &source,
83                    &connect_opts,
84                    dry_run,
85                    *ignore_missing,
86                    target_version,
87                )
88                .await?
89            }
90            MigrateCommand::Info {
91                source,
92                connect_opts,
93            } => migrate::info(&source, &connect_opts).await?,
94        },
95
96        Command::Database(database) => match database.command {
97            DatabaseCommand::Create { connect_opts } => database::create(&connect_opts).await?,
98            DatabaseCommand::Drop {
99                confirmation,
100                connect_opts,
101                force,
102            } => database::drop(&connect_opts, !confirmation.yes, force).await?,
103            DatabaseCommand::Reset {
104                confirmation,
105                source,
106                connect_opts,
107                force,
108            } => database::reset(&source, &connect_opts, !confirmation.yes, force).await?,
109            DatabaseCommand::Setup {
110                source,
111                connect_opts,
112            } => database::setup(&source, &connect_opts).await?,
113        },
114    };
115
116    Ok(())
117}
118
119/// Attempt to connect to the database server, retrying up to `ops.connect_timeout`.
120async fn connect(opts: &ConnectOpts) -> anyhow::Result<AnyConnection> {
121    retry_connect_errors(opts, AnyConnection::connect).await
122}
123
124/// Attempt an operation that may return errors like `ConnectionRefused`,
125/// retrying up until `ops.connect_timeout`.
126///
127/// The closure is passed `&ops.database_url` for easy composition.
128async fn retry_connect_errors<'a, F, Fut, T>(
129    opts: &'a ConnectOpts,
130    mut connect: F,
131) -> anyhow::Result<T>
132where
133    F: FnMut(&'a str) -> Fut,
134    Fut: Future<Output = sqlx::Result<T>> + 'a,
135{
136    install_default_drivers();
137
138    let db_url = opts.required_db_url()?;
139
140    backoff::future::retry(
141        backoff::ExponentialBackoffBuilder::new()
142            .with_max_elapsed_time(Some(Duration::from_secs(opts.connect_timeout)))
143            .build(),
144        || {
145            connect(db_url).map_err(|e| -> backoff::Error<anyhow::Error> {
146                if let sqlx::Error::Io(ref ioe) = e {
147                    match ioe.kind() {
148                        io::ErrorKind::ConnectionRefused
149                        | io::ErrorKind::ConnectionReset
150                        | io::ErrorKind::ConnectionAborted => {
151                            return backoff::Error::transient(e.into());
152                        }
153                        _ => (),
154                    }
155                }
156
157                backoff::Error::permanent(e.into())
158            })
159        },
160    )
161    .await
162}
163
164/// Install all currently compiled-in drivers for [`AnyConnection`] to use.
165///
166/// May be called multiple times; only the first call will install drivers, subsequent calls
167/// will have no effect.
168///
169/// ### Panics
170/// If [`install_drivers`] has already been called *not* through this function.
171///
172/// [`AnyConnection`]: sqlx_core::any::AnyConnection
173pub fn install_default_drivers() {
174    static ONCE: Once = Once::new();
175
176    ONCE.call_once(|| {
177        install_drivers(&[sqlx_scylladb_core::any::DRIVER])
178            .expect("non-default drivers already installed")
179    });
180}