sqlx_migrate_validate/
validate.rs

1use std::{
2    borrow::Cow,
3    collections::{HashMap, HashSet},
4};
5
6use async_trait::async_trait;
7use sqlx::migrate::{
8    AppliedMigration, Migrate, MigrateError, Migration, MigrationSource, Migrator,
9};
10
11use crate::error::ValidateError;
12
/// Adds a `validate` operation that checks the migrations already applied to a
/// database against a set of known migrations.
///
/// In this file the trait is implemented for [`Migrator`]; that implementation
/// builds a [`Validator`] from the migrator and delegates to it.
#[async_trait(?Send)]
pub trait Validate {
    /// Validate previously applied migrations against the migration source.
    /// Depending on the migration source this can be used to check if all migrations
    /// for the current version of the application have been applied.
    /// Use [`Validator::from_migrator`] to use the migrations available during compilation.
    ///
    /// # Errors
    ///
    /// Returns a [`ValidateError`] when the applied migrations do not match the
    /// migration source or when talking to the database fails.
    ///
    /// # Examples
    ///
    /// ```rust,no_run
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// # sqlx_rt::block_on(async move {
    /// # use sqlx_migrate_validate::Validator;
    /// // Use migrations that were in a local folder during build: ./tests/migrations-1
    /// let v = Validator::from_migrator(sqlx::migrate!("./tests/migrations-1"));
    ///
    /// // Create a connection pool
    /// let pool = sqlx_core::sqlite::SqlitePoolOptions::new().connect("sqlite::memory:").await?;
    /// let mut conn = pool.acquire().await?;
    ///
    /// // Validate the migrations
    /// v.validate(&mut *conn).await?;
    /// # Ok(())
    /// # })
    /// # }
    /// ```
    async fn validate<'c, C>(&self, conn: &mut C) -> Result<(), ValidateError>
    where
        C: Migrate;
}
43
/// Validates previously applied migrations against the migration source.
///
/// Depending on the migration source this can be used to check if all migrations
/// for the current version of the application have been applied.
/// Use [`Validator::from_migrator`] to use the migrations available during compilation,
/// or [`Validator::new`] to resolve them from a [`MigrationSource`] at runtime.
///
/// # Examples
///
/// ```rust,no_run
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// # sqlx_rt::block_on(async move {
/// # use sqlx_migrate_validate::Validator;
/// // Use migrations that were in a local folder during build: ./tests/migrations-1
/// let v = Validator::from_migrator(sqlx::migrate!("./tests/migrations-1"));
///
/// // Create a connection pool
/// let pool = sqlx_core::sqlite::SqlitePoolOptions::new().connect("sqlite::memory:").await?;
/// let mut conn = pool.acquire().await?;
///
/// // Validate the migrations
/// v.validate(&mut *conn).await?;
/// # Ok(())
/// # })
/// # }
/// ```
#[derive(Debug)]
pub struct Validator {
    /// The known migrations that the applied migrations are compared against.
    pub migrations: Cow<'static, [Migration]>,
    /// When `true`, an applied migration whose version is absent from
    /// `migrations` is not reported as an error.
    pub ignore_missing: bool,
    /// When `true`, [`Validator::validate`] calls `lock()` on the connection
    /// before checking and `unlock()` afterwards.
    pub locking: bool,
}
74
75impl Validator {
76    /// Creates a new instance with the given source. Please note that the source
77    /// is resolved at runtime and not at compile time.
78    /// You can use [`Validator::from<sqlx::Migrator>`] and the [`sqlx::migrate!`] macro
79    /// to embed the migrations into the binary during compile time.
80    ///
81    /// # Examples
82    ///
83    /// ```rust,no_run
84    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
85    /// # sqlx_rt::block_on(async move {
86    /// # use sqlx_migrate_validate::Validator;
87    /// use std::path::Path;
88    ///
89    /// // Read migrations from a local folder: ./tests/migrations-1
90    /// let v = Validator::new(Path::new("./tests/migrations-1")).await?;
91    ///
92    /// // Create a connection pool
93    /// let pool = sqlx::PgPool::connect("postgres://postgres:postgres@localhost:5432/postgres").await?;
94    /// let mut conn = pool.acquire().await?;
95    ///
96    /// // Validate the migrations
97    /// v.validate(&mut *conn).await?;
98    /// # Ok(())
99    /// # })
100    /// # }
101    /// ```
102    ///
103    /// See [MigrationSource] for details on structure of the `./tests/migrations-1` directory.
104    pub async fn new<'s, S>(source: S) -> Result<Self, MigrateError>
105    where
106        S: MigrationSource<'s>,
107    {
108        Ok(Self {
109            migrations: Cow::Owned(source.resolve().await.map_err(MigrateError::Source)?),
110            ignore_missing: false,
111            locking: true,
112        })
113    }
114
115    /// Creates a new instance with the migrations from the given migrator.
116    /// You can combine this with the [`sqlx::migrate!`] macro
117    /// to embed the migrations into the binary during compile time.
118    ///
119    /// # Examples
120    ///
121    /// ```rust,no_run
122    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
123    /// # sqlx_rt::block_on(async move {
124    /// # use sqlx_migrate_validate::Validator;
125    /// // Use migrations that were in a local folder during build: ./tests/migrations-1
126    /// let v = Validator::from_migrator(sqlx::migrate!("./tests/migrations-1"));
127    ///
128    /// // Create a connection pool
129    /// let pool = sqlx::PgPool::connect("postgres://postgres:postgres@localhost:5432/postgres").await?;
130    /// let mut conn = pool.acquire().await?;
131    ///
132    /// // Validate the migrations
133    /// v.validate(&mut *conn).await?;
134    /// # Ok(())
135    /// # })
136    /// # }
137    /// ```
138    pub fn from_migrator(migrator: Migrator) -> Self {
139        Self {
140            migrations: migrator.migrations.clone(),
141            ignore_missing: migrator.ignore_missing,
142            locking: migrator.locking,
143        }
144    }
145
146    pub async fn validate<'c, C>(&self, conn: &mut C) -> Result<(), ValidateError>
147    where
148        C: Migrate,
149    {
150        // lock the migrator to prevent other migrators from running
151        if self.locking {
152            conn.lock().await?;
153        }
154
155        let version = conn.dirty_version().await?;
156        if let Some(version) = version {
157            return Err(ValidateError::MigrateError(MigrateError::Dirty(version)));
158        }
159
160        let applied_migrations = conn.list_applied_migrations().await?;
161        validate_applied_migrations(&applied_migrations, self)?;
162
163        let applied_migrations: HashMap<_, _> = applied_migrations
164            .into_iter()
165            .map(|m| (m.version, m))
166            .collect();
167
168        for migration in self.migrations.iter() {
169            if migration.migration_type.is_down_migration() {
170                continue;
171            }
172
173            match applied_migrations.get(&migration.version) {
174                Some(applied_migration) => {
175                    if migration.checksum != applied_migration.checksum {
176                        return Err(ValidateError::MigrateError(MigrateError::VersionMismatch(
177                            migration.version,
178                        )));
179                    }
180                }
181                None => {
182                    return Err(ValidateError::VersionNotApplied(migration.version));
183                    // conn.apply(migration).await?;
184                }
185            }
186        }
187
188        // unlock the migrator to allow other migrators to run
189        // but do nothing as we already migrated
190        if self.locking {
191            conn.unlock().await?;
192        }
193
194        Ok(())
195    }
196}
197
198impl From<&Migrator> for Validator {
199    fn from(migrator: &Migrator) -> Self {
200        Self {
201            migrations: migrator.migrations.clone(),
202            ignore_missing: migrator.ignore_missing,
203            locking: migrator.locking,
204        }
205    }
206}
207
208impl From<Migrator> for Validator {
209    fn from(migrator: Migrator) -> Self {
210        Self::from(&migrator)
211    }
212}
213
214#[async_trait(?Send)]
215impl Validate for Migrator {
216    async fn validate<'c, C>(&self, conn: &mut C) -> Result<(), ValidateError>
217    where
218        C: Migrate,
219    {
220        Validator::from(self).validate(conn).await
221    }
222}
223
224fn validate_applied_migrations(
225    applied_migrations: &[AppliedMigration],
226    migrator: &Validator,
227) -> Result<(), MigrateError> {
228    if migrator.ignore_missing {
229        return Ok(());
230    }
231
232    let migrations: HashSet<_> = migrator.migrations.iter().map(|m| m.version).collect();
233
234    for applied_migration in applied_migrations {
235        if !migrations.contains(&applied_migration.version) {
236            return Err(MigrateError::VersionMissing(applied_migration.version));
237        }
238    }
239
240    Ok(())
241}
242
#[cfg(test)]
mod tests {
    use sqlx::migrate::MigrationType;

    use super::*;

    /// Builds a validator over `migrations` with `ignore_missing` off and `locking` on.
    fn validator_for(migrations: Vec<Migration>) -> Validator {
        Validator {
            migrations: Cow::Owned(migrations),
            ignore_missing: false,
            locking: true,
        }
    }

    /// Builds an applied migration; only the version is relevant for the
    /// function under test, so the checksum is left empty.
    fn applied(version: i64) -> AppliedMigration {
        AppliedMigration {
            version,
            checksum: Cow::Owned(vec![]),
        }
    }

    #[test]
    fn validate_applied_migrations_returns_ok_when_nothing_was_applied() {
        let nothing_applied: Vec<AppliedMigration> = Vec::new();
        let mut validator = validator_for(vec![]);

        // With and without `ignore_missing`, an empty database is valid.
        assert!(validate_applied_migrations(&nothing_applied, &validator).is_ok());

        validator.ignore_missing = true;
        assert!(validate_applied_migrations(&nothing_applied, &validator).is_ok());
    }

    #[test]
    fn validate_applied_migrations_returns_err_when_applied_migration_not_in_source() {
        let applied_migrations = vec![applied(1)];
        let validator = validator_for(vec![]);

        let result = validate_applied_migrations(&applied_migrations, &validator);

        if let Err(MigrateError::VersionMissing(version)) = result {
            assert_eq!(version, 1);
        } else {
            panic!("Unexpected error");
        }
    }

    #[test]
    fn validate_applied_migrations_returns_ok_when_applied_migration_in_source() {
        let applied_migrations = vec![applied(1)];

        // only the version is relevant for this method
        let known = Migration {
            version: 1,
            migration_type: MigrationType::ReversibleUp,
            checksum: Cow::Owned(vec![]),
            sql: Cow::Owned("".to_string()),
            description: Cow::Owned("".to_string()),
        };
        let validator = validator_for(vec![known]);

        if validate_applied_migrations(&applied_migrations, &validator).is_err() {
            panic!("Unexpected error");
        }
    }
}