sqlx_core_guts/migrate/
migrator.rs

1use crate::acquire::Acquire;
2use crate::migrate::{AppliedMigration, Migrate, MigrateError, Migration, MigrationSource};
3use std::borrow::Cow;
4use std::collections::{HashMap, HashSet};
5use std::ops::Deref;
6use std::slice;
7
/// A set of migrations that can be applied to (`run`) or reverted from (`undo`) a database.
#[derive(Debug)]
pub struct Migrator {
    /// All known migrations. `run` visits them in this order; `undo` visits them in reverse.
    pub migrations: Cow<'static, [Migration]>,
    /// When `true`, migrations recorded as applied in the database but absent from
    /// `migrations` are tolerated instead of being reported as an error.
    pub ignore_missing: bool,
}
13
14fn validate_applied_migrations(
15    applied_migrations: &[AppliedMigration],
16    migrator: &Migrator,
17) -> Result<(), MigrateError> {
18    if migrator.ignore_missing {
19        return Ok(());
20    }
21
22    let migrations: HashSet<_> = migrator.iter().map(|m| m.version).collect();
23
24    for applied_migration in applied_migrations {
25        if !migrations.contains(&applied_migration.version) {
26            return Err(MigrateError::VersionMissing(applied_migration.version));
27        }
28    }
29
30    Ok(())
31}
32
33impl Migrator {
34    /// Creates a new instance with the given source.
35    ///
36    /// # Examples
37    ///
38    /// ```rust,no_run
39    /// # use sqlx_core::migrate::MigrateError;
40    /// # fn main() -> Result<(), MigrateError> {
41    /// # sqlx_rt::block_on(async move {
42    /// # use sqlx_core::migrate::Migrator;
43    /// use std::path::Path;
44    ///
45    /// // Read migrations from a local folder: ./migrations
46    /// let m = Migrator::new(Path::new("./migrations")).await?;
47    /// # Ok(())
48    /// # })
49    /// # }
50    /// ```
51    /// See [MigrationSource] for details on structure of the `./migrations` directory.
52    pub async fn new<'s, S>(source: S) -> Result<Self, MigrateError>
53    where
54        S: MigrationSource<'s>,
55    {
56        Ok(Self {
57            migrations: Cow::Owned(source.resolve().await.map_err(MigrateError::Source)?),
58            ignore_missing: false,
59        })
60    }
61
62    /// Specify should ignore applied migrations that missing in the resolved migrations.
63    pub fn set_ignore_missing(&mut self, ignore_missing: bool) -> &Self {
64        self.ignore_missing = ignore_missing;
65        self
66    }
67
68    /// Get an iterator over all known migrations.
69    pub fn iter(&self) -> slice::Iter<'_, Migration> {
70        self.migrations.iter()
71    }
72
73    /// Run any pending migrations against the database; and, validate previously applied migrations
74    /// against the current migration source to detect accidental changes in previously-applied migrations.
75    ///
76    /// # Examples
77    ///
78    /// ```rust,no_run
79    /// # use sqlx_core::migrate::MigrateError;
80    /// # #[cfg(feature = "sqlite")]
81    /// # fn main() -> Result<(), MigrateError> {
82    /// #     sqlx_rt::block_on(async move {
83    /// # use sqlx_core::migrate::Migrator;
84    /// let m = Migrator::new(std::path::Path::new("./migrations")).await?;
85    /// let pool = sqlx_core::sqlite::SqlitePoolOptions::new().connect("sqlite::memory:").await?;
86    /// m.run(&pool).await
87    /// #     })
88    /// # }
89    /// ```
90    pub async fn run<'a, A>(&self, migrator: A) -> Result<(), MigrateError>
91    where
92        A: Acquire<'a>,
93        <A::Connection as Deref>::Target: Migrate,
94    {
95        let mut conn = migrator.acquire().await?;
96
97        // lock the database for exclusive access by the migrator
98        conn.lock().await?;
99
100        // creates [_migrations] table only if needed
101        // eventually this will likely migrate previous versions of the table
102        conn.ensure_migrations_table().await?;
103
104        let version = conn.dirty_version().await?;
105        if let Some(version) = version {
106            return Err(MigrateError::Dirty(version));
107        }
108
109        let applied_migrations = conn.list_applied_migrations().await?;
110        validate_applied_migrations(&applied_migrations, self)?;
111
112        let applied_migrations: HashMap<_, _> = applied_migrations
113            .into_iter()
114            .map(|m| (m.version, m))
115            .collect();
116
117        for migration in self.iter() {
118            if migration.migration_type.is_down_migration() {
119                continue;
120            }
121
122            match applied_migrations.get(&migration.version) {
123                Some(applied_migration) => {
124                    if migration.checksum != applied_migration.checksum {
125                        return Err(MigrateError::VersionMismatch(migration.version));
126                    }
127                }
128                None => {
129                    conn.apply(migration).await?;
130                }
131            }
132        }
133
134        // unlock the migrator to allow other migrators to run
135        // but do nothing as we already migrated
136        conn.unlock().await?;
137
138        Ok(())
139    }
140
141    /// Run down migrations against the database until a specific version.
142    ///
143    /// # Examples
144    ///
145    /// ```rust,no_run
146    /// # use sqlx_core::migrate::MigrateError;
147    /// # #[cfg(feature = "sqlite")]
148    /// # fn main() -> Result<(), MigrateError> {
149    /// #     sqlx_rt::block_on(async move {
150    /// # use sqlx_core::migrate::Migrator;
151    /// let m = Migrator::new(std::path::Path::new("./migrations")).await?;
152    /// let pool = sqlx_core::sqlite::SqlitePoolOptions::new().connect("sqlite::memory:").await?;
153    /// m.undo(&pool, 4).await
154    /// #     })
155    /// # }
156    /// ```
157    pub async fn undo<'a, A>(&self, migrator: A, target: i64) -> Result<(), MigrateError>
158    where
159        A: Acquire<'a>,
160        <A::Connection as Deref>::Target: Migrate,
161    {
162        let mut conn = migrator.acquire().await?;
163
164        // lock the database for exclusive access by the migrator
165        conn.lock().await?;
166
167        // creates [_migrations] table only if needed
168        // eventually this will likely migrate previous versions of the table
169        conn.ensure_migrations_table().await?;
170
171        let version = conn.dirty_version().await?;
172        if let Some(version) = version {
173            return Err(MigrateError::Dirty(version));
174        }
175
176        let applied_migrations = conn.list_applied_migrations().await?;
177        validate_applied_migrations(&applied_migrations, self)?;
178
179        let applied_migrations: HashMap<_, _> = applied_migrations
180            .into_iter()
181            .map(|m| (m.version, m))
182            .collect();
183
184        for migration in self
185            .iter()
186            .rev()
187            .filter(|m| m.migration_type.is_down_migration())
188            .filter(|m| applied_migrations.contains_key(&m.version))
189            .filter(|m| m.version > target)
190        {
191            conn.revert(migration).await?;
192        }
193
194        // unlock the migrator to allow other migrators to run
195        // but do nothing as we already migrated
196        conn.unlock().await?;
197
198        Ok(())
199    }
200}