sqlx_pg_migrate/
lib.rs

1//! A library to run migrations on a PostgreSQL database using SQLx.
2//!
//! Make a directory that contains your migrations. The library will run through
//! all the files in sorted order. The suggested naming convention is
5//! `000_first.sql`, `001_second.sql` and so on.
6//!
7//! The library:
8//! 1. Will create the DB if necessary.
9//! 1. Will create a table named `sqlx_pg_migrate` to manage the migration state.
//! 1. Will run all pending migrations in a single transaction, so either all of
//!    them are applied or none are (creating the database happens beforehand).
12//! 1. Expects you to never delete or rename a migration.
13//! 1. Expects you to not put a new migration between two existing ones.
14//! 1. Expects file names and contents to be UTF-8.
15//! 1. There are no rollbacks - just write a new migration.
16//!
17//! You'll need to add these two crates as dependencies:
18//! ```toml
19//! [dependencies]
20//! include_dir = "0.6"
21//! sqlx-pg-migrate = "1.0"
22//! ```
23//!
24//! The usage looks like this:
25//!
26//! ```
27//! use sqlx_pg_migrate::migrate;
28//! use include_dir::{include_dir, Dir};
29//!
30//! // Use include_dir! to include your migrations into your binary.
31//! // The path here is relative to your cargo root.
32//! static MIGRATIONS: Dir = include_dir!("migrations");
33//!
34//! # #[async_attributes::main]
35//! # async fn main() -> std::result::Result<(), sqlx_pg_migrate::Error> {
36//! #    let db_url = std::env::var("DATABASE_URL")
37//! #        .unwrap_or(String::from("postgresql://localhost/sqlxpgmigrate_doctest"));
38//! // Somewhere, probably in main, call the migrate function with your DB URL
39//! // and the included migrations.
40//! migrate(&db_url, &MIGRATIONS).await?;
41//! #    Ok(())
42//! # }
43//! ```
44
45use include_dir::Dir;
46use sqlx::postgres::PgRow;
47use sqlx::{Connection, Executor, PgConnection, Row};
48use thiserror::Error;
49
50/// The various kinds of errors that can arise when running the migrations.
51#[derive(Error, Debug)]
52pub enum Error {
53    #[error("expected migration `{0}` to already have been run")]
54    MissingMigration(String),
55
56    #[error("invalid URL `{0}`: could not determine DB name")]
57    InvalidURL(String),
58
59    #[error("error connecting to existing database: {}", .source)]
60    ExistingConnectError { source: sqlx::Error },
61
62    #[error("error connecting to base URL `{}` to create DB: {}", .url, .source)]
63    BaseConnect { url: String, source: sqlx::Error },
64
65    #[error("error finding current migrations: {}", .source)]
66    CurrentMigrations { source: sqlx::Error },
67
68    #[error("invalid utf-8 bytes in migration content: {0}")]
69    InvalidMigrationContent(std::path::PathBuf),
70
71    #[error("invalid utf-8 bytes in migration path: {0}")]
72    InvalidMigrationPath(std::path::PathBuf),
73
74    #[error("more migrations run than are known indicating possibly deleted migrations")]
75    DeletedMigrations,
76
77    #[error(transparent)]
78    DB(#[from] sqlx::Error),
79}
80
81type Result<T> = std::result::Result<T, Error>;
82
83fn base_and_db(url: &str) -> Result<(&str, &str)> {
84    let base_split: Vec<&str> = url.rsplitn(2, '/').collect();
85    if base_split.len() != 2 {
86        return Err(Error::InvalidURL(url.to_string()));
87    }
88    let qmark_split: Vec<&str> = base_split[0].splitn(2, '?').collect();
89    Ok((base_split[1], qmark_split[0]))
90}
91
92async fn maybe_make_db(url: &str) -> Result<()> {
93    match PgConnection::connect(url).await {
94        Ok(_) => return Ok(()), // it exists, we're done
95        Err(err) => {
96            if let sqlx::Error::Database(dberr) = err {
97                // this indicates the database doesn't exist
98                if let Some("3D000") = dberr.code().as_deref() {
99                    Ok(()) // it doesn't exist, continue to create it
100                } else {
101                    Err(Error::ExistingConnectError {
102                        source: sqlx::Error::Database(dberr),
103                    })
104                }
105            } else {
106                Err(Error::ExistingConnectError { source: err })
107            }
108        }
109    }?;
110
111    let (base_url, db_name) = base_and_db(url)?;
112    let mut db = match PgConnection::connect(&format!("{}/postgres", base_url)).await {
113        Ok(db) => db,
114        Err(err) => {
115            return Err(Error::BaseConnect {
116                url: base_url.to_string(),
117                source: err,
118            })
119        }
120    };
121    sqlx::query(&format!(r#"CREATE DATABASE "{}""#, db_name))
122        .execute(&mut db)
123        .await?;
124    Ok(())
125}
126
127async fn get_migrated(db: &mut PgConnection) -> Result<Vec<String>> {
128    let migrated = sqlx::query("SELECT migration FROM sqlx_pg_migrate ORDER BY id")
129        .try_map(|row: PgRow| row.try_get("migration"))
130        .fetch_all(db)
131        .await;
132    match migrated {
133        Ok(migrated) => Ok(migrated),
134        Err(err) => {
135            if let sqlx::Error::Database(dberr) = err {
136                // this indicates the table doesn't exist
137                if let Some("42P01") = dberr.code().as_deref() {
138                    Ok(vec![])
139                } else {
140                    Err(Error::CurrentMigrations {
141                        source: sqlx::Error::Database(dberr),
142                    })
143                }
144            } else {
145                Err(Error::CurrentMigrations { source: err })
146            }
147        }
148    }
149}
150
151/// Runs the migrations contained in the directory. See module documentation for
152/// more information.
153pub async fn migrate(url: &str, dir: &Dir<'_>) -> Result<()> {
154    maybe_make_db(url).await?;
155    let mut db = PgConnection::connect(url).await?;
156    let migrated = get_migrated(&mut db).await?;
157    let mut tx = db.begin().await?;
158    if migrated.is_empty() {
159        sqlx::query(
160            r#"
161                CREATE TABLE IF NOT EXISTS sqlx_pg_migrate (
162                    id SERIAL PRIMARY KEY,
163                    migration TEXT UNIQUE,
164                    created TIMESTAMP NOT NULL DEFAULT current_timestamp
165                );
166            "#,
167        )
168        .execute(&mut tx)
169        .await?;
170    }
171    let mut files: Vec<_> = dir.files().collect();
172    if migrated.len() > files.len() {
173        return Err(Error::DeletedMigrations);
174    }
175    files.sort_by(|a, b| a.path().partial_cmp(b.path()).unwrap());
176    for (pos, f) in files.iter().enumerate() {
177        let path = f
178            .path()
179            .to_str()
180            .ok_or_else(|| Error::InvalidMigrationPath(f.path().to_owned()))?;
181
182        if pos < migrated.len() {
183            if migrated[pos] != path {
184                return Err(Error::MissingMigration(path.to_owned()));
185            }
186            continue;
187        }
188
189        let content = f
190            .contents_utf8()
191            .ok_or_else(|| Error::InvalidMigrationContent(f.path().to_owned()))?;
192        tx.execute(content).await?;
193        sqlx::query("INSERT INTO sqlx_pg_migrate (migration) VALUES ($1)")
194            .bind(path)
195            .execute(&mut tx)
196            .await?;
197    }
198    tx.commit().await?;
199    Ok(())
200}
201
#[cfg(test)]
mod tests {
    use super::{base_and_db, migrate, Error};
    use include_dir::{include_dir, Dir};

    static MIGRATIONS: Dir = include_dir!("migrations");

    /// Migrating twice must succeed; the second run should be a no-op.
    /// Needs a reachable PostgreSQL server (`DATABASE_URL`, or a local default).
    #[async_attributes::test]
    async fn it_works() -> std::result::Result<(), Error> {
        let url = std::env::var("DATABASE_URL").unwrap_or_else(|_| {
            String::from("postgresql://localhost/sqlxpgmigrate1?sslmode=disable")
        });
        for _ in 0..2 {
            if let Err(err) = migrate(&url, &MIGRATIONS).await {
                panic!("migrate failed with: {}", err);
            }
        }
        Ok(())
    }

    /// URL splitting does not need a database server.
    #[test]
    fn base_and_db_splits_url() {
        assert_eq!(
            base_and_db("postgresql://localhost/sqlxpgmigrate1?sslmode=disable").unwrap(),
            ("postgresql://localhost", "sqlxpgmigrate1")
        );
        assert!(matches!(base_and_db("localhost"), Err(Error::InvalidURL(_))));
    }
}