1use std::io;
2use std::time::Duration;
3
4use anyhow::Result;
5use futures::{Future, TryFutureExt};
6
7use sqlx::{AnyConnection, Connection};
8use tokio::{select, signal};
9
10use crate::opt::{Command, ConnectOpts, DatabaseCommand, MigrateCommand};
11
12mod database;
13mod metadata;
14#[cfg(feature = "completions")]
17mod completions;
18mod migrate;
19mod opt;
20mod prepare;
21
22pub use crate::opt::Opt;
23
24pub fn maybe_apply_dotenv() {
26 if std::env::args().any(|arg| arg == "--no-dotenv") {
27 return;
28 }
29
30 dotenvy::dotenv().ok();
31}
32
33pub async fn run(opt: Opt) -> Result<()> {
34 let ctrlc_fut = signal::ctrl_c();
40 let do_run_fut = do_run(opt);
41
42 select! {
43 biased;
44 _ = ctrlc_fut => {
45 Ok(())
46 },
47 do_run_outcome = do_run_fut => {
48 do_run_outcome
49 }
50 }
51}
52
53async fn do_run(opt: Opt) -> Result<()> {
54 match opt.command {
55 Command::Migrate(migrate) => match migrate.command {
56 MigrateCommand::Add {
57 source,
58 description,
59 reversible,
60 sequential,
61 timestamp,
62 } => migrate::add(&source, &description, reversible, sequential, timestamp).await?,
63 MigrateCommand::Run {
64 source,
65 dry_run,
66 ignore_missing,
67 connect_opts,
68 target_version,
69 } => {
70 migrate::run(
71 &source,
72 &connect_opts,
73 dry_run,
74 *ignore_missing,
75 target_version,
76 )
77 .await?
78 }
79 MigrateCommand::Revert {
80 source,
81 dry_run,
82 ignore_missing,
83 connect_opts,
84 target_version,
85 } => {
86 migrate::revert(
87 &source,
88 &connect_opts,
89 dry_run,
90 *ignore_missing,
91 target_version,
92 )
93 .await?
94 }
95 MigrateCommand::Info {
96 source,
97 connect_opts,
98 } => migrate::info(&source, &connect_opts).await?,
99 MigrateCommand::BuildScript { source, force } => migrate::build_script(&source, force)?,
100 },
101
102 Command::Database(database) => match database.command {
103 DatabaseCommand::Create { connect_opts } => database::create(&connect_opts).await?,
104 DatabaseCommand::Drop {
105 confirmation,
106 connect_opts,
107 force,
108 } => database::drop(&connect_opts, !confirmation.yes, force).await?,
109 DatabaseCommand::Reset {
110 confirmation,
111 source,
112 connect_opts,
113 force,
114 } => database::reset(&source, &connect_opts, !confirmation.yes, force).await?,
115 DatabaseCommand::Setup {
116 source,
117 connect_opts,
118 } => database::setup(&source, &connect_opts).await?,
119 },
120
121 Command::Prepare {
122 check,
123 all,
124 workspace,
125 connect_opts,
126 args,
127 } => prepare::run(check, all, workspace, connect_opts, args).await?,
128
129 #[cfg(feature = "completions")]
130 Command::Completions { shell } => completions::run(shell),
131 };
132
133 Ok(())
134}
135
136async fn connect(opts: &ConnectOpts) -> anyhow::Result<AnyConnection> {
138 retry_connect_errors(opts, AnyConnection::connect).await
139}
140
141async fn retry_connect_errors<'a, F, Fut, T>(
146 opts: &'a ConnectOpts,
147 mut connect: F,
148) -> anyhow::Result<T>
149where
150 F: FnMut(&'a str) -> Fut,
151 Fut: Future<Output = sqlx::Result<T>> + 'a,
152{
153 sqlx::any::install_default_drivers();
154
155 let db_url = opts.required_db_url()?;
156
157 backoff::future::retry(
158 backoff::ExponentialBackoffBuilder::new()
159 .with_max_elapsed_time(Some(Duration::from_secs(opts.connect_timeout)))
160 .build(),
161 || {
162 connect(db_url).map_err(|e| -> backoff::Error<anyhow::Error> {
163 if let sqlx::Error::Io(ref ioe) = e {
164 match ioe.kind() {
165 io::ErrorKind::ConnectionRefused
166 | io::ErrorKind::ConnectionReset
167 | io::ErrorKind::ConnectionAborted => {
168 return backoff::Error::transient(e.into());
169 }
170 _ => (),
171 }
172 }
173
174 backoff::Error::permanent(e.into())
175 })
176 },
177 )
178 .await
179}