1use std::path::{Path, PathBuf};
2
3use sentinel_driver::advisory_lock::PgAdvisoryLock;
4use sentinel_driver::{Connection, Pool};
5
6use crate::SNTL_MIGRATE_LOCK_ID;
7use crate::checksum::sha256_of_sql;
8use crate::discover::discover;
9use crate::error::{Error, Result};
10use crate::migration::{Migration, TxMode, Version};
11use crate::tracking;
12
13#[derive(Debug, Default)]
15pub struct MigrationReport {
16 pub applied: Vec<Version>,
17}
18
19#[derive(Debug)]
21pub struct MigrationStatus {
22 pub version: Version,
23 pub state: State,
24 pub checksum: Option<String>,
25}
26
/// Per-migration state reported by [`Migrator::info`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum State {
    /// Recorded in the tracking table with a checksum matching the current SQL.
    Applied,
    /// Not yet recorded in the tracking table.
    Pending,
    /// Recorded, but the recorded checksum differs from the current SQL's checksum.
    ChecksumDrift,
}
33
/// Schema-cache refresh settings configured through [`Migrator::with_refresh`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RefreshConfig {
    /// Connection string handed to the schema refresh.
    pub conn_str: String,
    /// Cache directory handed to the schema refresh.
    pub cache_dir: PathBuf,
}
39
40pub struct Migrator {
42 migrations: Vec<Migration>,
43 source: MigrationSource,
44 refresh: Option<RefreshConfig>,
45}
46
/// Origin of a migrator's migration list.
#[derive(Debug)]
enum MigrationSource {
    /// Migrations embedded in the binary (built by `Migrator::from_static`).
    Static,
    /// Migrations discovered under this directory (built by `Migrator::from_dir`).
    Dir(PathBuf),
}
52
53impl Migrator {
54 pub fn from_dir(path: impl AsRef<Path>) -> Result<Self> {
55 let path = path.as_ref().to_path_buf();
56 let migrations = discover(&path)?;
57 Ok(Self {
58 migrations,
59 source: MigrationSource::Dir(path),
60 refresh: None,
61 })
62 }
63
64 pub fn from_static(entries: &'static [(&'static str, &'static str, TxMode)]) -> Self {
65 let migrations = entries
66 .iter()
67 .map(|(v, sql, mode)| Migration {
68 version: v
69 .parse()
70 .expect("compile-time migration version must be valid"),
71 sql: (*sql).to_string(),
72 tx_mode: *mode,
73 })
74 .collect();
75 Self {
76 migrations,
77 source: MigrationSource::Static,
78 refresh: None,
79 }
80 }
81
82 pub fn with_refresh(
83 mut self,
84 conn_str: impl Into<String>,
85 cache_dir: impl Into<PathBuf>,
86 ) -> Self {
87 self.refresh = Some(RefreshConfig {
88 conn_str: conn_str.into(),
89 cache_dir: cache_dir.into(),
90 });
91 self
92 }
93
94 pub async fn run(&self, pool: &Pool) -> Result<MigrationReport> {
95 let mut conn = pool.acquire().await?;
96 let lock = PgAdvisoryLock::new(SNTL_MIGRATE_LOCK_ID);
97 let guard = lock.acquire(&mut conn).await?;
98
99 let result = self.run_locked(&mut conn).await;
100
101 guard.release(&mut conn).await?;
102 let report = result?;
103
104 if let Some(cfg) = &self.refresh {
105 crate::refresh::refresh_schema(&cfg.conn_str, &cfg.cache_dir).await?;
106 }
107 Ok(report)
108 }
109
110 async fn run_locked(&self, conn: &mut Connection) -> Result<MigrationReport> {
111 tracking::ensure(conn).await?;
112 let applied = tracking::applied(conn).await?;
113 let applied_set: std::collections::BTreeSet<Version> =
114 applied.iter().map(|(v, _)| v.clone()).collect();
115 let highest_applied = applied_set.iter().max().cloned();
116
117 let mut report = MigrationReport::default();
118 for m in &self.migrations {
119 if applied_set.contains(&m.version) {
120 continue;
121 }
122 if let Some(highest) = &highest_applied {
123 if m.version < *highest {
124 return Err(Error::OutOfOrder {
125 pending: m.version.clone(),
126 highest_applied: highest.clone(),
127 });
128 }
129 }
130
131 let started = std::time::Instant::now();
132 apply_one(conn, m).await?;
133 let checksum = sha256_of_sql(&m.sql);
134 tracking::record(conn, &m.version, &checksum).await?;
135 conn.instrumentation()
136 .on_event(&sentinel_driver::Event::MigrationApply {
137 version: m.version.as_str(),
138 duration: started.elapsed(),
139 checksum: &checksum,
140 });
141 report.applied.push(m.version.clone());
142 }
143 Ok(report)
144 }
145
146 pub async fn info(&self, pool: &Pool) -> Result<Vec<MigrationStatus>> {
147 let mut conn = pool.acquire().await?;
148 tracking::ensure(&mut conn).await?;
149 let applied = tracking::applied(&mut conn).await?;
150 let applied_map: std::collections::BTreeMap<Version, String> =
151 applied.into_iter().collect();
152
153 let mut out = Vec::with_capacity(self.migrations.len() + applied_map.len());
154 for m in &self.migrations {
155 if let Some(recorded) = applied_map.get(&m.version) {
156 let current = sha256_of_sql(&m.sql);
157 let state = if current == *recorded {
158 State::Applied
159 } else {
160 conn.instrumentation()
161 .on_event(&sentinel_driver::Event::MigrationDrift {
162 version: m.version.as_str(),
163 recorded,
164 current: ¤t,
165 });
166 State::ChecksumDrift
167 };
168 out.push(MigrationStatus {
169 version: m.version.clone(),
170 state,
171 checksum: Some(recorded.clone()),
172 });
173 } else {
174 out.push(MigrationStatus {
175 version: m.version.clone(),
176 state: State::Pending,
177 checksum: None,
178 });
179 }
180 }
181 Ok(out)
182 }
183
184 pub fn migrations(&self) -> &[Migration] {
185 &self.migrations
186 }
187
188 pub fn source_path(&self) -> Option<&Path> {
189 match &self.source {
190 MigrationSource::Dir(p) => Some(p.as_path()),
191 MigrationSource::Static => None,
192 }
193 }
194}
195
196async fn apply_one(conn: &mut Connection, m: &Migration) -> Result<()> {
197 match m.tx_mode {
198 TxMode::PerMigration => {
199 conn.execute("BEGIN", &[]).await?;
200 if let Err(e) = conn.execute(&m.sql, &[]).await {
201 conn.execute("ROLLBACK", &[]).await.ok();
202 return Err(Error::ApplyFailed {
203 version: m.version.clone(),
204 source: e,
205 });
206 }
207 conn.execute("COMMIT", &[])
208 .await
209 .map_err(|source| Error::ApplyFailed {
210 version: m.version.clone(),
211 source,
212 })?;
213 }
214 TxMode::None => {
215 conn.execute(&m.sql, &[])
216 .await
217 .map_err(|source| Error::ApplyFailed {
218 version: m.version.clone(),
219 source,
220 })?;
221 }
222 }
223 Ok(())
224}