1use include_dir::Dir;
46use sqlx::postgres::PgRow;
47use sqlx::{Connection, Executor, PgConnection, Row};
48use thiserror::Error;
49
50#[derive(Error, Debug)]
52pub enum Error {
53 #[error("expected migration `{0}` to already have been run")]
54 MissingMigration(String),
55
56 #[error("invalid URL `{0}`: could not determine DB name")]
57 InvalidURL(String),
58
59 #[error("error connecting to existing database: {}", .source)]
60 ExistingConnectError { source: sqlx::Error },
61
62 #[error("error connecting to base URL `{}` to create DB: {}", .url, .source)]
63 BaseConnect { url: String, source: sqlx::Error },
64
65 #[error("error finding current migrations: {}", .source)]
66 CurrentMigrations { source: sqlx::Error },
67
68 #[error("invalid utf-8 bytes in migration content: {0}")]
69 InvalidMigrationContent(std::path::PathBuf),
70
71 #[error("invalid utf-8 bytes in migration path: {0}")]
72 InvalidMigrationPath(std::path::PathBuf),
73
74 #[error("more migrations run than are known indicating possibly deleted migrations")]
75 DeletedMigrations,
76
77 #[error(transparent)]
78 DB(#[from] sqlx::Error),
79}
80
81type Result<T> = std::result::Result<T, Error>;
82
83fn base_and_db(url: &str) -> Result<(&str, &str)> {
84 let base_split: Vec<&str> = url.rsplitn(2, '/').collect();
85 if base_split.len() != 2 {
86 return Err(Error::InvalidURL(url.to_string()));
87 }
88 let qmark_split: Vec<&str> = base_split[0].splitn(2, '?').collect();
89 Ok((base_split[1], qmark_split[0]))
90}
91
92async fn maybe_make_db(url: &str) -> Result<()> {
93 match PgConnection::connect(url).await {
94 Ok(_) => return Ok(()), Err(err) => {
96 if let sqlx::Error::Database(dberr) = err {
97 if let Some("3D000") = dberr.code().as_deref() {
99 Ok(()) } else {
101 Err(Error::ExistingConnectError {
102 source: sqlx::Error::Database(dberr),
103 })
104 }
105 } else {
106 Err(Error::ExistingConnectError { source: err })
107 }
108 }
109 }?;
110
111 let (base_url, db_name) = base_and_db(url)?;
112 let mut db = match PgConnection::connect(&format!("{}/postgres", base_url)).await {
113 Ok(db) => db,
114 Err(err) => {
115 return Err(Error::BaseConnect {
116 url: base_url.to_string(),
117 source: err,
118 })
119 }
120 };
121 sqlx::query(&format!(r#"CREATE DATABASE "{}""#, db_name))
122 .execute(&mut db)
123 .await?;
124 Ok(())
125}
126
127async fn get_migrated(db: &mut PgConnection) -> Result<Vec<String>> {
128 let migrated = sqlx::query("SELECT migration FROM sqlx_pg_migrate ORDER BY id")
129 .try_map(|row: PgRow| row.try_get("migration"))
130 .fetch_all(db)
131 .await;
132 match migrated {
133 Ok(migrated) => Ok(migrated),
134 Err(err) => {
135 if let sqlx::Error::Database(dberr) = err {
136 if let Some("42P01") = dberr.code().as_deref() {
138 Ok(vec![])
139 } else {
140 Err(Error::CurrentMigrations {
141 source: sqlx::Error::Database(dberr),
142 })
143 }
144 } else {
145 Err(Error::CurrentMigrations { source: err })
146 }
147 }
148 }
149}
150
151pub async fn migrate(url: &str, dir: &Dir<'_>) -> Result<()> {
154 maybe_make_db(url).await?;
155 let mut db = PgConnection::connect(url).await?;
156 let migrated = get_migrated(&mut db).await?;
157 let mut tx = db.begin().await?;
158 if migrated.is_empty() {
159 sqlx::query(
160 r#"
161 CREATE TABLE IF NOT EXISTS sqlx_pg_migrate (
162 id SERIAL PRIMARY KEY,
163 migration TEXT UNIQUE,
164 created TIMESTAMP NOT NULL DEFAULT current_timestamp
165 );
166 "#,
167 )
168 .execute(&mut tx)
169 .await?;
170 }
171 let mut files: Vec<_> = dir.files().collect();
172 if migrated.len() > files.len() {
173 return Err(Error::DeletedMigrations);
174 }
175 files.sort_by(|a, b| a.path().partial_cmp(b.path()).unwrap());
176 for (pos, f) in files.iter().enumerate() {
177 let path = f
178 .path()
179 .to_str()
180 .ok_or_else(|| Error::InvalidMigrationPath(f.path().to_owned()))?;
181
182 if pos < migrated.len() {
183 if migrated[pos] != path {
184 return Err(Error::MissingMigration(path.to_owned()));
185 }
186 continue;
187 }
188
189 let content = f
190 .contents_utf8()
191 .ok_or_else(|| Error::InvalidMigrationContent(f.path().to_owned()))?;
192 tx.execute(content).await?;
193 sqlx::query("INSERT INTO sqlx_pg_migrate (migration) VALUES ($1)")
194 .bind(path)
195 .execute(&mut tx)
196 .await?;
197 }
198 tx.commit().await?;
199 Ok(())
200}
201
#[cfg(test)]
mod tests {
    use super::migrate;
    use include_dir::{include_dir, Dir};

    /// Migration files embedded from the crate's `migrations` directory.
    static MIGRATIONS: Dir = include_dir!("migrations");

    /// Runs the migrations twice against `DATABASE_URL`, or a local default if
    /// it is unset. The second run is expected to find everything already applied.
    #[async_attributes::test]
    async fn it_works() -> std::result::Result<(), super::Error> {
        let url = std::env::var("DATABASE_URL").unwrap_or_else(|_| {
            String::from("postgresql://localhost/sqlxpgmigrate1?sslmode=disable")
        });
        for _attempt in 0..2 {
            if let Err(err) = migrate(&url, &MIGRATIONS).await {
                panic!("migrate failed with: {}", err);
            }
        }
        Ok(())
    }
}