sqlx_core_oldapi/migrate/migrator.rs
1use crate::acquire::Acquire;
2use crate::migrate::{AppliedMigration, Migrate, MigrateError, Migration, MigrationSource};
3use std::borrow::Cow;
4use std::collections::{HashMap, HashSet};
5use std::ops::Deref;
6use std::slice;
7
8#[derive(Debug)]
9pub struct Migrator {
10 pub migrations: Cow<'static, [Migration]>,
11 pub ignore_missing: bool,
12 pub locking: bool,
13}
14
15fn validate_applied_migrations(
16 applied_migrations: &[AppliedMigration],
17 migrator: &Migrator,
18) -> Result<(), MigrateError> {
19 if migrator.ignore_missing {
20 return Ok(());
21 }
22
23 let migrations: HashSet<_> = migrator.iter().map(|m| m.version).collect();
24
25 for applied_migration in applied_migrations {
26 if !migrations.contains(&applied_migration.version) {
27 return Err(MigrateError::VersionMissing(applied_migration.version));
28 }
29 }
30
31 Ok(())
32}
33
34impl Migrator {
35 /// Creates a new instance with the given source.
36 ///
37 /// # Examples
38 ///
39 /// ```rust,no_run
40 /// # use sqlx_core_oldapi::migrate::MigrateError;
41 /// # fn main() -> Result<(), MigrateError> {
42 /// # sqlx_rt::block_on(async move {
43 /// # use sqlx_core_oldapi::migrate::Migrator;
44 /// use std::path::Path;
45 ///
46 /// // Read migrations from a local folder: ./migrations
47 /// let m = Migrator::new(Path::new("./migrations")).await?;
48 /// # Ok(())
49 /// # })
50 /// # }
51 /// ```
52 /// See [MigrationSource] for details on structure of the `./migrations` directory.
53 pub async fn new<'s, S>(source: S) -> Result<Self, MigrateError>
54 where
55 S: MigrationSource<'s>,
56 {
57 Ok(Self {
58 migrations: Cow::Owned(source.resolve().await.map_err(MigrateError::Source)?),
59 ignore_missing: false,
60 locking: true,
61 })
62 }
63
64 /// Specify whether applied migrations that are missing from the resolved migrations should be ignored.
65 pub fn set_ignore_missing(&mut self, ignore_missing: bool) -> &Self {
66 self.ignore_missing = ignore_missing;
67 self
68 }
69
70 /// Specify whether or not to lock database during migration. Defaults to `true`.
71 ///
72 /// ### Warning
73 /// Disabling locking can lead to errors or data loss if multiple clients attempt to apply migrations simultaneously
74 /// without some sort of mutual exclusion.
75 ///
76 /// This should only be used if the database does not support locking, e.g. CockroachDB which talks the Postgres
77 /// protocol but does not support advisory locks used by SQLx's migrations support for Postgres.
78 pub fn set_locking(&mut self, locking: bool) -> &Self {
79 self.locking = locking;
80 self
81 }
82
83 /// Get an iterator over all known migrations.
84 pub fn iter(&self) -> slice::Iter<'_, Migration> {
85 self.migrations.iter()
86 }
87
88 /// Run any pending migrations against the database; and, validate previously applied migrations
89 /// against the current migration source to detect accidental changes in previously-applied migrations.
90 ///
91 /// # Examples
92 ///
93 /// ```rust,no_run
94 /// # #[cfg(feature = "sqlite")]
95 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
96 /// # sqlx_rt::block_on(async move {
97 /// # use sqlx_core_oldapi::migrate::Migrator;
98 /// let m = Migrator::new(std::path::Path::new("./migrations")).await?;
99 /// let pool = sqlx_core_oldapi::sqlite::SqlitePoolOptions::new().connect("sqlite::memory:").await?;
100 /// m.run(&pool).await?;
101 /// # Ok(())
102 /// # })
103 /// # }
104 /// ```
105 pub async fn run<'a, A>(&self, migrator: A) -> Result<(), MigrateError>
106 where
107 A: Acquire<'a>,
108 <A::Connection as Deref>::Target: Migrate,
109 {
110 let mut conn = migrator
111 .acquire()
112 .await
113 .map_err(MigrateError::AcquireConnection)?;
114 self.run_direct(&mut *conn).await
115 }
116
117 // Getting around the annoying "implementation of `Acquire` is not general enough" error
118 #[doc(hidden)]
119 pub async fn run_direct<C>(&self, conn: &mut C) -> Result<(), MigrateError>
120 where
121 C: Migrate,
122 {
123 // lock the database for exclusive access by the migrator
124 if self.locking {
125 conn.lock().await.map_err(MigrateError::AcquireConnection)?;
126 }
127
128 // creates [_migrations] table only if needed
129 // eventually this will likely migrate previous versions of the table
130 conn.ensure_migrations_table()
131 .await
132 .map_err(MigrateError::AccessMigrationMetadata)?;
133
134 let version = conn
135 .dirty_version()
136 .await
137 .map_err(MigrateError::AccessMigrationMetadata)?;
138 if let Some(version) = version {
139 return Err(MigrateError::Dirty(version));
140 }
141
142 let applied_migrations = conn
143 .list_applied_migrations()
144 .await
145 .map_err(MigrateError::AccessMigrationMetadata)?;
146 validate_applied_migrations(&applied_migrations, self)?;
147
148 let applied_migrations: HashMap<_, _> = applied_migrations
149 .into_iter()
150 .map(|m| (m.version, m))
151 .collect();
152
153 for migration in self.iter() {
154 if migration.migration_type.is_down_migration() {
155 continue;
156 }
157
158 match applied_migrations.get(&migration.version) {
159 Some(applied_migration) => {
160 if migration.checksum != applied_migration.checksum {
161 return Err(MigrateError::VersionMismatch(migration.version));
162 }
163 }
164 None => {
165 conn.apply(migration)
166 .await
167 .map_err(|e| MigrateError::Execute(migration.version, e))?;
168 }
169 }
170 }
171
172 // unlock the migrator to allow other migrators to run
173 // but do nothing as we already migrated
174 if self.locking {
175 conn.unlock()
176 .await
177 .map_err(MigrateError::AcquireConnection)?;
178 }
179
180 Ok(())
181 }
182
183 /// Run down migrations against the database until a specific version.
184 ///
185 /// # Examples
186 ///
187 /// ```rust,no_run
188 /// # #[cfg(feature = "sqlite")]
189 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
190 /// # sqlx_rt::block_on(async move {
191 /// # use sqlx_core_oldapi::migrate::Migrator;
192 /// let m = Migrator::new(std::path::Path::new("./migrations")).await?;
193 /// let pool = sqlx_core_oldapi::sqlite::SqlitePoolOptions::new().connect("sqlite::memory:").await?;
194 /// m.undo(&pool, 4).await?;
195 /// # Ok(())
196 /// # })
197 /// # }
198 /// ```
199 pub async fn undo<'a, A>(&self, migrator: A, target: i64) -> Result<(), MigrateError>
200 where
201 A: Acquire<'a>,
202 <A::Connection as Deref>::Target: Migrate,
203 {
204 let mut conn = migrator
205 .acquire()
206 .await
207 .map_err(MigrateError::AcquireConnection)?;
208
209 // lock the database for exclusive access by the migrator
210 if self.locking {
211 conn.lock().await.map_err(MigrateError::AcquireConnection)?;
212 }
213
214 // creates [_migrations] table only if needed
215 // eventually this will likely migrate previous versions of the table
216 conn.ensure_migrations_table()
217 .await
218 .map_err(MigrateError::AccessMigrationMetadata)?;
219
220 let version = conn
221 .dirty_version()
222 .await
223 .map_err(MigrateError::AccessMigrationMetadata)?;
224 if let Some(version) = version {
225 return Err(MigrateError::Dirty(version));
226 }
227
228 let applied_migrations = conn
229 .list_applied_migrations()
230 .await
231 .map_err(MigrateError::AccessMigrationMetadata)?;
232 validate_applied_migrations(&applied_migrations, self)?;
233
234 let applied_migrations: HashMap<_, _> = applied_migrations
235 .into_iter()
236 .map(|m| (m.version, m))
237 .collect();
238
239 for migration in self
240 .iter()
241 .rev()
242 .filter(|m| m.migration_type.is_down_migration())
243 .filter(|m| applied_migrations.contains_key(&m.version))
244 .filter(|m| m.version > target)
245 {
246 conn.revert(migration)
247 .await
248 .map_err(|e| MigrateError::Execute(migration.version, e))?;
249 }
250
251 // unlock the migrator to allow other migrators to run
252 // but do nothing as we already migrated
253 if self.locking {
254 conn.unlock()
255 .await
256 .map_err(MigrateError::AcquireConnection)?;
257 }
258
259 Ok(())
260 }
261}