sqlx_core_oldapi/migrate/
migrator.rs

1use crate::acquire::Acquire;
2use crate::migrate::{AppliedMigration, Migrate, MigrateError, Migration, MigrationSource};
3use std::borrow::Cow;
4use std::collections::{HashMap, HashSet};
5use std::ops::Deref;
6use std::slice;
7
/// Holds a resolved list of migrations plus the options that control how they are applied
/// to (or reverted from) a database connection.
#[derive(Debug)]
pub struct Migrator {
    /// Every known migration, both up and down. `run` walks this slice front to back and
    /// `undo` walks it back to front.
    pub migrations: Cow<'static, [Migration]>,
    /// When `true`, versions recorded as applied in the database but absent from `migrations`
    /// are tolerated instead of failing with `MigrateError::VersionMissing`.
    pub ignore_missing: bool,
    /// Whether `run`/`undo` take the database's migration lock around their work.
    pub locking: bool,
}
14
15fn validate_applied_migrations(
16    applied_migrations: &[AppliedMigration],
17    migrator: &Migrator,
18) -> Result<(), MigrateError> {
19    if migrator.ignore_missing {
20        return Ok(());
21    }
22
23    let migrations: HashSet<_> = migrator.iter().map(|m| m.version).collect();
24
25    for applied_migration in applied_migrations {
26        if !migrations.contains(&applied_migration.version) {
27            return Err(MigrateError::VersionMissing(applied_migration.version));
28        }
29    }
30
31    Ok(())
32}
33
34impl Migrator {
35    /// Creates a new instance with the given source.
36    ///
37    /// # Examples
38    ///
39    /// ```rust,no_run
40    /// # use sqlx_core_oldapi::migrate::MigrateError;
41    /// # fn main() -> Result<(), MigrateError> {
42    /// # sqlx_rt::block_on(async move {
43    /// # use sqlx_core_oldapi::migrate::Migrator;
44    /// use std::path::Path;
45    ///
46    /// // Read migrations from a local folder: ./migrations
47    /// let m = Migrator::new(Path::new("./migrations")).await?;
48    /// # Ok(())
49    /// # })
50    /// # }
51    /// ```
52    /// See [MigrationSource] for details on structure of the `./migrations` directory.
53    pub async fn new<'s, S>(source: S) -> Result<Self, MigrateError>
54    where
55        S: MigrationSource<'s>,
56    {
57        Ok(Self {
58            migrations: Cow::Owned(source.resolve().await.map_err(MigrateError::Source)?),
59            ignore_missing: false,
60            locking: true,
61        })
62    }
63
64    /// Specify whether applied migrations that are missing from the resolved migrations should be ignored.
65    pub fn set_ignore_missing(&mut self, ignore_missing: bool) -> &Self {
66        self.ignore_missing = ignore_missing;
67        self
68    }
69
70    /// Specify whether or not to lock database during migration. Defaults to `true`.
71    ///
72    /// ### Warning
73    /// Disabling locking can lead to errors or data loss if multiple clients attempt to apply migrations simultaneously
74    /// without some sort of mutual exclusion.
75    ///
76    /// This should only be used if the database does not support locking, e.g. CockroachDB which talks the Postgres
77    /// protocol but does not support advisory locks used by SQLx's migrations support for Postgres.
78    pub fn set_locking(&mut self, locking: bool) -> &Self {
79        self.locking = locking;
80        self
81    }
82
83    /// Get an iterator over all known migrations.
84    pub fn iter(&self) -> slice::Iter<'_, Migration> {
85        self.migrations.iter()
86    }
87
88    /// Run any pending migrations against the database; and, validate previously applied migrations
89    /// against the current migration source to detect accidental changes in previously-applied migrations.
90    ///
91    /// # Examples
92    ///
93    /// ```rust,no_run
94    /// # #[cfg(feature = "sqlite")]
95    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
96    /// #     sqlx_rt::block_on(async move {
97    /// # use sqlx_core_oldapi::migrate::Migrator;
98    /// let m = Migrator::new(std::path::Path::new("./migrations")).await?;
99    /// let pool = sqlx_core_oldapi::sqlite::SqlitePoolOptions::new().connect("sqlite::memory:").await?;
100    /// m.run(&pool).await?;
101    /// #     Ok(())
102    /// #     })
103    /// # }
104    /// ```
105    pub async fn run<'a, A>(&self, migrator: A) -> Result<(), MigrateError>
106    where
107        A: Acquire<'a>,
108        <A::Connection as Deref>::Target: Migrate,
109    {
110        let mut conn = migrator
111            .acquire()
112            .await
113            .map_err(MigrateError::AcquireConnection)?;
114        self.run_direct(&mut *conn).await
115    }
116
117    // Getting around the annoying "implementation of `Acquire` is not general enough" error
118    #[doc(hidden)]
119    pub async fn run_direct<C>(&self, conn: &mut C) -> Result<(), MigrateError>
120    where
121        C: Migrate,
122    {
123        // lock the database for exclusive access by the migrator
124        if self.locking {
125            conn.lock().await.map_err(MigrateError::AcquireConnection)?;
126        }
127
128        // creates [_migrations] table only if needed
129        // eventually this will likely migrate previous versions of the table
130        conn.ensure_migrations_table()
131            .await
132            .map_err(MigrateError::AccessMigrationMetadata)?;
133
134        let version = conn
135            .dirty_version()
136            .await
137            .map_err(MigrateError::AccessMigrationMetadata)?;
138        if let Some(version) = version {
139            return Err(MigrateError::Dirty(version));
140        }
141
142        let applied_migrations = conn
143            .list_applied_migrations()
144            .await
145            .map_err(MigrateError::AccessMigrationMetadata)?;
146        validate_applied_migrations(&applied_migrations, self)?;
147
148        let applied_migrations: HashMap<_, _> = applied_migrations
149            .into_iter()
150            .map(|m| (m.version, m))
151            .collect();
152
153        for migration in self.iter() {
154            if migration.migration_type.is_down_migration() {
155                continue;
156            }
157
158            match applied_migrations.get(&migration.version) {
159                Some(applied_migration) => {
160                    if migration.checksum != applied_migration.checksum {
161                        return Err(MigrateError::VersionMismatch(migration.version));
162                    }
163                }
164                None => {
165                    conn.apply(migration)
166                        .await
167                        .map_err(|e| MigrateError::Execute(migration.version, e))?;
168                }
169            }
170        }
171
172        // unlock the migrator to allow other migrators to run
173        // but do nothing as we already migrated
174        if self.locking {
175            conn.unlock()
176                .await
177                .map_err(MigrateError::AcquireConnection)?;
178        }
179
180        Ok(())
181    }
182
183    /// Run down migrations against the database until a specific version.
184    ///
185    /// # Examples
186    ///
187    /// ```rust,no_run
188    /// # #[cfg(feature = "sqlite")]
189    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
190    /// #     sqlx_rt::block_on(async move {
191    /// # use sqlx_core_oldapi::migrate::Migrator;
192    /// let m = Migrator::new(std::path::Path::new("./migrations")).await?;
193    /// let pool = sqlx_core_oldapi::sqlite::SqlitePoolOptions::new().connect("sqlite::memory:").await?;
194    /// m.undo(&pool, 4).await?;
195    /// #     Ok(())
196    /// #     })
197    /// # }
198    /// ```
199    pub async fn undo<'a, A>(&self, migrator: A, target: i64) -> Result<(), MigrateError>
200    where
201        A: Acquire<'a>,
202        <A::Connection as Deref>::Target: Migrate,
203    {
204        let mut conn = migrator
205            .acquire()
206            .await
207            .map_err(MigrateError::AcquireConnection)?;
208
209        // lock the database for exclusive access by the migrator
210        if self.locking {
211            conn.lock().await.map_err(MigrateError::AcquireConnection)?;
212        }
213
214        // creates [_migrations] table only if needed
215        // eventually this will likely migrate previous versions of the table
216        conn.ensure_migrations_table()
217            .await
218            .map_err(MigrateError::AccessMigrationMetadata)?;
219
220        let version = conn
221            .dirty_version()
222            .await
223            .map_err(MigrateError::AccessMigrationMetadata)?;
224        if let Some(version) = version {
225            return Err(MigrateError::Dirty(version));
226        }
227
228        let applied_migrations = conn
229            .list_applied_migrations()
230            .await
231            .map_err(MigrateError::AccessMigrationMetadata)?;
232        validate_applied_migrations(&applied_migrations, self)?;
233
234        let applied_migrations: HashMap<_, _> = applied_migrations
235            .into_iter()
236            .map(|m| (m.version, m))
237            .collect();
238
239        for migration in self
240            .iter()
241            .rev()
242            .filter(|m| m.migration_type.is_down_migration())
243            .filter(|m| applied_migrations.contains_key(&m.version))
244            .filter(|m| m.version > target)
245        {
246            conn.revert(migration)
247                .await
248                .map_err(|e| MigrateError::Execute(migration.version, e))?;
249        }
250
251        // unlock the migrator to allow other migrators to run
252        // but do nothing as we already migrated
253        if self.locking {
254            conn.unlock()
255                .await
256                .map_err(MigrateError::AcquireConnection)?;
257        }
258
259        Ok(())
260    }
261}