Skip to main content

sntl_migrate/
runner.rs

1use std::path::{Path, PathBuf};
2
3use sentinel_driver::advisory_lock::PgAdvisoryLock;
4use sentinel_driver::{Connection, Pool};
5
6use crate::SNTL_MIGRATE_LOCK_ID;
7use crate::checksum::sha256_of_sql;
8use crate::discover::discover;
9use crate::error::{Error, Result};
10use crate::migration::{Migration, TxMode, Version};
11use crate::tracking;
12
13/// Result of a single `Migrator::run` invocation.
14#[derive(Debug, Default)]
15pub struct MigrationReport {
16    pub applied: Vec<Version>,
17}
18
19/// One row in `sntl migrate info`.
20#[derive(Debug)]
21pub struct MigrationStatus {
22    pub version: Version,
23    pub state: State,
24    pub checksum: Option<String>,
25}
26
/// Status of a single migration relative to the tracking table.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum State {
    /// Recorded as applied, and the recorded checksum matches the current SQL.
    Applied,
    /// Not recorded as applied yet.
    Pending,
    /// Recorded as applied, but the current SQL's checksum differs from the recorded one.
    ChecksumDrift,
}
33
/// Settings for the schema-cache refresh that `Migrator::run` performs after all
/// migrations succeed.
#[derive(Clone)]
pub struct RefreshConfig {
    /// Connection string handed to the schema refresh. May embed credentials.
    pub conn_str: String,
    /// Cache directory handed to the schema refresh.
    pub cache_dir: PathBuf,
}

/// Hand-written so that `conn_str` never appears in logs or panic messages.
/// Connection strings frequently carry passwords, and a derived `Debug` would print
/// them verbatim.
impl std::fmt::Debug for RefreshConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RefreshConfig")
            .field("conn_str", &"<redacted>")
            .field("cache_dir", &self.cache_dir)
            .finish()
    }
}
39
40/// Top-level migration runner.
41pub struct Migrator {
42    migrations: Vec<Migration>,
43    source: MigrationSource,
44    refresh: Option<RefreshConfig>,
45}
46
/// Where a `Migrator`'s migrations came from.
#[derive(Debug)]
enum MigrationSource {
    /// Discovered on disk under this directory (`Migrator::from_dir`).
    Dir(PathBuf),
    /// Embedded in the binary (`Migrator::from_static`); there is no directory.
    Static,
}
52
53impl Migrator {
54    pub fn from_dir(path: impl AsRef<Path>) -> Result<Self> {
55        let path = path.as_ref().to_path_buf();
56        let migrations = discover(&path)?;
57        Ok(Self {
58            migrations,
59            source: MigrationSource::Dir(path),
60            refresh: None,
61        })
62    }
63
64    pub fn from_static(entries: &'static [(&'static str, &'static str, TxMode)]) -> Self {
65        let migrations = entries
66            .iter()
67            .map(|(v, sql, mode)| Migration {
68                version: v
69                    .parse()
70                    .expect("compile-time migration version must be valid"),
71                sql: (*sql).to_string(),
72                tx_mode: *mode,
73            })
74            .collect();
75        Self {
76            migrations,
77            source: MigrationSource::Static,
78            refresh: None,
79        }
80    }
81
82    pub fn with_refresh(
83        mut self,
84        conn_str: impl Into<String>,
85        cache_dir: impl Into<PathBuf>,
86    ) -> Self {
87        self.refresh = Some(RefreshConfig {
88            conn_str: conn_str.into(),
89            cache_dir: cache_dir.into(),
90        });
91        self
92    }
93
94    pub async fn run(&self, pool: &Pool) -> Result<MigrationReport> {
95        let mut conn = pool.acquire().await?;
96        let lock = PgAdvisoryLock::new(SNTL_MIGRATE_LOCK_ID);
97        let guard = lock.acquire(&mut conn).await?;
98
99        let result = self.run_locked(&mut conn).await;
100
101        guard.release(&mut conn).await?;
102        let report = result?;
103
104        if let Some(cfg) = &self.refresh {
105            crate::refresh::refresh_schema(&cfg.conn_str, &cfg.cache_dir).await?;
106        }
107        Ok(report)
108    }
109
110    async fn run_locked(&self, conn: &mut Connection) -> Result<MigrationReport> {
111        tracking::ensure(conn).await?;
112        let applied = tracking::applied(conn).await?;
113        let applied_set: std::collections::BTreeSet<Version> =
114            applied.iter().map(|(v, _)| v.clone()).collect();
115        let highest_applied = applied_set.iter().max().cloned();
116
117        let mut report = MigrationReport::default();
118        for m in &self.migrations {
119            if applied_set.contains(&m.version) {
120                continue;
121            }
122            if let Some(highest) = &highest_applied {
123                if m.version < *highest {
124                    return Err(Error::OutOfOrder {
125                        pending: m.version.clone(),
126                        highest_applied: highest.clone(),
127                    });
128                }
129            }
130
131            let started = std::time::Instant::now();
132            apply_one(conn, m).await?;
133            let checksum = sha256_of_sql(&m.sql);
134            tracking::record(conn, &m.version, &checksum).await?;
135            conn.instrumentation()
136                .on_event(&sentinel_driver::Event::MigrationApply {
137                    version: m.version.as_str(),
138                    duration: started.elapsed(),
139                    checksum: &checksum,
140                });
141            report.applied.push(m.version.clone());
142        }
143        Ok(report)
144    }
145
146    pub async fn info(&self, pool: &Pool) -> Result<Vec<MigrationStatus>> {
147        let mut conn = pool.acquire().await?;
148        tracking::ensure(&mut conn).await?;
149        let applied = tracking::applied(&mut conn).await?;
150        let applied_map: std::collections::BTreeMap<Version, String> =
151            applied.into_iter().collect();
152
153        let mut out = Vec::with_capacity(self.migrations.len() + applied_map.len());
154        for m in &self.migrations {
155            if let Some(recorded) = applied_map.get(&m.version) {
156                let current = sha256_of_sql(&m.sql);
157                let state = if current == *recorded {
158                    State::Applied
159                } else {
160                    conn.instrumentation()
161                        .on_event(&sentinel_driver::Event::MigrationDrift {
162                            version: m.version.as_str(),
163                            recorded,
164                            current: &current,
165                        });
166                    State::ChecksumDrift
167                };
168                out.push(MigrationStatus {
169                    version: m.version.clone(),
170                    state,
171                    checksum: Some(recorded.clone()),
172                });
173            } else {
174                out.push(MigrationStatus {
175                    version: m.version.clone(),
176                    state: State::Pending,
177                    checksum: None,
178                });
179            }
180        }
181        Ok(out)
182    }
183
184    pub fn migrations(&self) -> &[Migration] {
185        &self.migrations
186    }
187
188    pub fn source_path(&self) -> Option<&Path> {
189        match &self.source {
190            MigrationSource::Dir(p) => Some(p.as_path()),
191            MigrationSource::Static => None,
192        }
193    }
194}
195
196async fn apply_one(conn: &mut Connection, m: &Migration) -> Result<()> {
197    match m.tx_mode {
198        TxMode::PerMigration => {
199            conn.execute("BEGIN", &[]).await?;
200            if let Err(e) = conn.execute(&m.sql, &[]).await {
201                conn.execute("ROLLBACK", &[]).await.ok();
202                return Err(Error::ApplyFailed {
203                    version: m.version.clone(),
204                    source: e,
205                });
206            }
207            conn.execute("COMMIT", &[])
208                .await
209                .map_err(|source| Error::ApplyFailed {
210                    version: m.version.clone(),
211                    source,
212                })?;
213        }
214        TxMode::None => {
215            conn.execute(&m.sql, &[])
216                .await
217                .map_err(|source| Error::ApplyFailed {
218                    version: m.version.clone(),
219                    source,
220                })?;
221        }
222    }
223    Ok(())
224}