sqlx_migrate_validate/validate.rs
1use std::{
2 borrow::Cow,
3 collections::{HashMap, HashSet},
4};
5
6use async_trait::async_trait;
7use sqlx::migrate::{
8 AppliedMigration, Migrate, MigrateError, Migration, MigrationSource, Migrator,
9};
10
11use crate::error::ValidateError;
12
13#[async_trait(?Send)]
14pub trait Validate {
15 /// Validate previously applied migrations against the migration source.
16 /// Depending on the migration source this can be used to check if all migrations
17 /// for the current version of the application have been applied.
18 /// Use [`Validator::from_migrator`] to use the migrations available during compilation.
19 ///
20 /// # Examples
21 ///
22 /// ```rust,no_run
23 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
24 /// # sqlx_rt::block_on(async move {
25 /// # use sqlx_migrate_validate::Validator;
26 /// // Use migrations that were in a local folder during build: ./tests/migrations-1
27 /// let v = Validator::from_migrator(sqlx::migrate!("./tests/migrations-1"));
28 ///
29 /// // Create a connection pool
30 /// let pool = sqlx_core::sqlite::SqlitePoolOptions::new().connect("sqlite::memory:").await?;
31 /// let mut conn = pool.acquire().await?;
32 ///
33 /// // Validate the migrations
34 /// v.validate(&mut *conn).await?;
35 /// # Ok(())
36 /// # })
37 /// # }
38 /// ```
39 async fn validate<'c, C>(&self, conn: &mut C) -> Result<(), ValidateError>
40 where
41 C: Migrate;
42}
43
44/// Validate previously applied migrations against the migration source.
45/// Depending on the migration source this can be used to check if all migrations
46/// for the current version of the application have been applied.
47/// Use [`Validator::from_migrator`] to use the migrations available during compilation.
48///
49/// # Examples
50///
51/// ```rust,no_run
52/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
53/// # sqlx_rt::block_on(async move {
54/// # use sqlx_migrate_validate::Validator;
55/// // Use migrations that were in a local folder during build: ./tests/migrations-1
56/// let v = Validator::from_migrator(sqlx::migrate!("./tests/migrations-1"));
57///
58/// // Create a connection pool
59/// let pool = sqlx_core::sqlite::SqlitePoolOptions::new().connect("sqlite::memory:").await?;
60/// let mut conn = pool.acquire().await?;
61///
62/// // Validate the migrations
63/// v.validate(&mut *conn).await?;
64/// # Ok(())
65/// # })
66/// # }
67/// ```
68#[derive(Debug)]
69pub struct Validator {
70 pub migrations: Cow<'static, [Migration]>,
71 pub ignore_missing: bool,
72 pub locking: bool,
73}
74
75impl Validator {
76 /// Creates a new instance with the given source. Please note that the source
77 /// is resolved at runtime and not at compile time.
78 /// You can use [`Validator::from<sqlx::Migrator>`] and the [`sqlx::migrate!`] macro
79 /// to embed the migrations into the binary during compile time.
80 ///
81 /// # Examples
82 ///
83 /// ```rust,no_run
84 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
85 /// # sqlx_rt::block_on(async move {
86 /// # use sqlx_migrate_validate::Validator;
87 /// use std::path::Path;
88 ///
89 /// // Read migrations from a local folder: ./tests/migrations-1
90 /// let v = Validator::new(Path::new("./tests/migrations-1")).await?;
91 ///
92 /// // Create a connection pool
93 /// let pool = sqlx::PgPool::connect("postgres://postgres:postgres@localhost:5432/postgres").await?;
94 /// let mut conn = pool.acquire().await?;
95 ///
96 /// // Validate the migrations
97 /// v.validate(&mut *conn).await?;
98 /// # Ok(())
99 /// # })
100 /// # }
101 /// ```
102 ///
103 /// See [MigrationSource] for details on structure of the `./tests/migrations-1` directory.
104 pub async fn new<'s, S>(source: S) -> Result<Self, MigrateError>
105 where
106 S: MigrationSource<'s>,
107 {
108 Ok(Self {
109 migrations: Cow::Owned(source.resolve().await.map_err(MigrateError::Source)?),
110 ignore_missing: false,
111 locking: true,
112 })
113 }
114
115 /// Creates a new instance with the migrations from the given migrator.
116 /// You can combine this with the [`sqlx::migrate!`] macro
117 /// to embed the migrations into the binary during compile time.
118 ///
119 /// # Examples
120 ///
121 /// ```rust,no_run
122 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
123 /// # sqlx_rt::block_on(async move {
124 /// # use sqlx_migrate_validate::Validator;
125 /// // Use migrations that were in a local folder during build: ./tests/migrations-1
126 /// let v = Validator::from_migrator(sqlx::migrate!("./tests/migrations-1"));
127 ///
128 /// // Create a connection pool
129 /// let pool = sqlx::PgPool::connect("postgres://postgres:postgres@localhost:5432/postgres").await?;
130 /// let mut conn = pool.acquire().await?;
131 ///
132 /// // Validate the migrations
133 /// v.validate(&mut *conn).await?;
134 /// # Ok(())
135 /// # })
136 /// # }
137 /// ```
138 pub fn from_migrator(migrator: Migrator) -> Self {
139 Self {
140 migrations: migrator.migrations.clone(),
141 ignore_missing: migrator.ignore_missing,
142 locking: migrator.locking,
143 }
144 }
145
146 pub async fn validate<'c, C>(&self, conn: &mut C) -> Result<(), ValidateError>
147 where
148 C: Migrate,
149 {
150 // lock the migrator to prevent other migrators from running
151 if self.locking {
152 conn.lock().await?;
153 }
154
155 let version = conn.dirty_version().await?;
156 if let Some(version) = version {
157 return Err(ValidateError::MigrateError(MigrateError::Dirty(version)));
158 }
159
160 let applied_migrations = conn.list_applied_migrations().await?;
161 validate_applied_migrations(&applied_migrations, self)?;
162
163 let applied_migrations: HashMap<_, _> = applied_migrations
164 .into_iter()
165 .map(|m| (m.version, m))
166 .collect();
167
168 for migration in self.migrations.iter() {
169 if migration.migration_type.is_down_migration() {
170 continue;
171 }
172
173 match applied_migrations.get(&migration.version) {
174 Some(applied_migration) => {
175 if migration.checksum != applied_migration.checksum {
176 return Err(ValidateError::MigrateError(MigrateError::VersionMismatch(
177 migration.version,
178 )));
179 }
180 }
181 None => {
182 return Err(ValidateError::VersionNotApplied(migration.version));
183 // conn.apply(migration).await?;
184 }
185 }
186 }
187
188 // unlock the migrator to allow other migrators to run
189 // but do nothing as we already migrated
190 if self.locking {
191 conn.unlock().await?;
192 }
193
194 Ok(())
195 }
196}
197
198impl From<&Migrator> for Validator {
199 fn from(migrator: &Migrator) -> Self {
200 Self {
201 migrations: migrator.migrations.clone(),
202 ignore_missing: migrator.ignore_missing,
203 locking: migrator.locking,
204 }
205 }
206}
207
208impl From<Migrator> for Validator {
209 fn from(migrator: Migrator) -> Self {
210 Self::from(&migrator)
211 }
212}
213
214#[async_trait(?Send)]
215impl Validate for Migrator {
216 async fn validate<'c, C>(&self, conn: &mut C) -> Result<(), ValidateError>
217 where
218 C: Migrate,
219 {
220 Validator::from(self).validate(conn).await
221 }
222}
223
224fn validate_applied_migrations(
225 applied_migrations: &[AppliedMigration],
226 migrator: &Validator,
227) -> Result<(), MigrateError> {
228 if migrator.ignore_missing {
229 return Ok(());
230 }
231
232 let migrations: HashSet<_> = migrator.migrations.iter().map(|m| m.version).collect();
233
234 for applied_migration in applied_migrations {
235 if !migrations.contains(&applied_migration.version) {
236 return Err(MigrateError::VersionMissing(applied_migration.version));
237 }
238 }
239
240 Ok(())
241}
242
#[cfg(test)]
mod tests {
    use sqlx::migrate::MigrationType;

    use super::*;

    #[test]
    fn validate_applied_migrations_returns_ok_when_nothing_was_applied() {
        let applied_migrations: Vec<AppliedMigration> = vec![];
        let mut validator = Validator {
            migrations: Cow::Owned(vec![]),
            ignore_missing: false,
            locking: true,
        };

        assert!(validate_applied_migrations(&applied_migrations, &validator).is_ok());

        validator.ignore_missing = true;
        assert!(validate_applied_migrations(&applied_migrations, &validator).is_ok());
    }

    #[test]
    fn validate_applied_migrations_returns_err_when_applied_migration_not_in_source() {
        let applied_migrations = vec![AppliedMigration {
            version: 1,

            // only the version is relevant for this method
            checksum: Cow::Owned(vec![]),
        }];
        let validator = Validator {
            migrations: Cow::Owned(vec![]),
            ignore_missing: false,
            locking: true,
        };

        match validate_applied_migrations(&applied_migrations, &validator) {
            Err(MigrateError::VersionMissing(i)) => assert_eq!(i, 1),
            other => panic!("expected VersionMissing(1), got {:?}", other),
        }
    }

    #[test]
    fn validate_applied_migrations_returns_ok_when_missing_migrations_are_ignored() {
        // Same situation as above, but unknown applied versions are tolerated.
        let applied_migrations = vec![AppliedMigration {
            version: 1,

            // only the version is relevant for this method
            checksum: Cow::Owned(vec![]),
        }];
        let validator = Validator {
            migrations: Cow::Owned(vec![]),
            ignore_missing: true,
            locking: true,
        };

        if let Err(err) = validate_applied_migrations(&applied_migrations, &validator) {
            panic!("expected Ok, got {:?}", err);
        }
    }

    #[test]
    fn validate_applied_migrations_returns_ok_when_applied_migration_in_source() {
        let applied_migrations = vec![AppliedMigration {
            version: 1,

            // only the version is relevant for this method
            checksum: Cow::Owned(vec![]),
        }];
        let validator = Validator {
            migrations: Cow::Owned(vec![Migration {
                version: 1,

                // only the version is relevant for this method
                migration_type: MigrationType::ReversibleUp,
                checksum: Cow::Owned(vec![]),
                sql: Cow::Owned("".to_string()),
                description: Cow::Owned("".to_string()),
            }]),
            ignore_missing: false,
            locking: true,
        };

        if let Err(err) = validate_applied_migrations(&applied_migrations, &validator) {
            panic!("expected Ok, got {:?}", err);
        }
    }
}