1use std::time::Duration;
2use std::{io, sync::Once};
3
4use anyhow::Result;
5use futures_util::{Future, TryFutureExt};
6
7use sqlx::any::install_drivers;
8use sqlx::{AnyConnection, Connection};
9use tokio::{select, signal};
10
11use crate::opt::{Command, ConnectOpts, DatabaseCommand, MigrateCommand};
12
13mod database;
14mod migrate;
15mod opt;
16
17pub use crate::opt::Opt;
18
19pub fn maybe_apply_dotenv() {
21 if std::env::args().any(|arg| arg == "--no-dotenv") {
22 return;
23 }
24
25 dotenvy::dotenv().ok();
26}
27
28pub async fn run(opt: Opt) -> Result<()> {
29 let ctrlc_fut = signal::ctrl_c();
35 let do_run_fut = do_run(opt);
36
37 select! {
38 biased;
39 _ = ctrlc_fut => {
40 Ok(())
41 },
42 do_run_outcome = do_run_fut => {
43 do_run_outcome
44 }
45 }
46}
47
48async fn do_run(opt: Opt) -> Result<()> {
49 match opt.command {
50 Command::Migrate(migrate) => match migrate.command {
51 MigrateCommand::Add {
52 source,
53 description,
54 reversible,
55 sequential,
56 timestamp,
57 } => migrate::add(&source, &description, reversible, sequential, timestamp).await?,
58 MigrateCommand::Run {
59 source,
60 dry_run,
61 ignore_missing,
62 connect_opts,
63 target_version,
64 } => {
65 migrate::run(
66 &source,
67 &connect_opts,
68 dry_run,
69 *ignore_missing,
70 target_version,
71 )
72 .await?
73 }
74 MigrateCommand::Revert {
75 source,
76 dry_run,
77 ignore_missing,
78 connect_opts,
79 target_version,
80 } => {
81 migrate::revert(
82 &source,
83 &connect_opts,
84 dry_run,
85 *ignore_missing,
86 target_version,
87 )
88 .await?
89 }
90 MigrateCommand::Info {
91 source,
92 connect_opts,
93 } => migrate::info(&source, &connect_opts).await?,
94 },
95
96 Command::Database(database) => match database.command {
97 DatabaseCommand::Create { connect_opts } => database::create(&connect_opts).await?,
98 DatabaseCommand::Drop {
99 confirmation,
100 connect_opts,
101 force,
102 } => database::drop(&connect_opts, !confirmation.yes, force).await?,
103 DatabaseCommand::Reset {
104 confirmation,
105 source,
106 connect_opts,
107 force,
108 } => database::reset(&source, &connect_opts, !confirmation.yes, force).await?,
109 DatabaseCommand::Setup {
110 source,
111 connect_opts,
112 } => database::setup(&source, &connect_opts).await?,
113 },
114 };
115
116 Ok(())
117}
118
119async fn connect(opts: &ConnectOpts) -> anyhow::Result<AnyConnection> {
121 retry_connect_errors(opts, AnyConnection::connect).await
122}
123
124async fn retry_connect_errors<'a, F, Fut, T>(
129 opts: &'a ConnectOpts,
130 mut connect: F,
131) -> anyhow::Result<T>
132where
133 F: FnMut(&'a str) -> Fut,
134 Fut: Future<Output = sqlx::Result<T>> + 'a,
135{
136 install_default_drivers();
137
138 let db_url = opts.required_db_url()?;
139
140 backoff::future::retry(
141 backoff::ExponentialBackoffBuilder::new()
142 .with_max_elapsed_time(Some(Duration::from_secs(opts.connect_timeout)))
143 .build(),
144 || {
145 connect(db_url).map_err(|e| -> backoff::Error<anyhow::Error> {
146 if let sqlx::Error::Io(ref ioe) = e {
147 match ioe.kind() {
148 io::ErrorKind::ConnectionRefused
149 | io::ErrorKind::ConnectionReset
150 | io::ErrorKind::ConnectionAborted => {
151 return backoff::Error::transient(e.into());
152 }
153 _ => (),
154 }
155 }
156
157 backoff::Error::permanent(e.into())
158 })
159 },
160 )
161 .await
162}
163
164pub fn install_default_drivers() {
174 static ONCE: Once = Once::new();
175
176 ONCE.call_once(|| {
177 install_drivers(&[sqlx_scylladb_core::any::DRIVER])
178 .expect("non-default drivers already installed")
179 });
180}