1use crate::error::{DatabaseError as _, Error, TernResult};
9use crate::migration::{AppliedMigration, Migration, MigrationContext, MigrationId};
10
11use chrono::{DateTime, Utc};
12use display_json::{DebugAsJson, DisplayAsJsonPretty};
13use serde::Serialize;
14use std::collections::HashSet;
15use std::fmt::Write;
16
17pub struct Runner<C: MigrationContext> {
19 context: C,
20}
21
22impl<C> Runner<C>
23where
24 C: MigrationContext,
25{
26 pub fn new(context: C) -> Self {
28 Self { context }
29 }
30
31 pub async fn init_history(&mut self) -> TernResult<()> {
33 self.context.check_history_table().await
34 }
35
36 pub async fn drop_history(&mut self) -> TernResult<()> {
38 self.context.drop_history_table().await
39 }
40
41 async fn validate_source(&mut self) -> TernResult<()> {
43 self.context.check_history_table().await?;
44 let applied: HashSet<MigrationId> = self
45 .context
46 .previously_applied()
47 .await?
48 .into_iter()
49 .map(MigrationId::from)
50 .collect();
51 let source: HashSet<MigrationId> = self
52 .context
53 .migration_set(None)
54 .migration_ids()
55 .into_iter()
56 .collect();
57
58 check_migrations_in_sync(applied, source)
59 }
60
61 fn validate_target(
63 &self,
64 last_applied: Option<i64>,
65 target_version: Option<i64>,
66 ) -> TernResult<()> {
67 let Some(source) = self.context.migration_set(None).max() else {
68 return Ok(());
69 };
70 if let Some(target) = target_version {
71 match last_applied {
72 Some(applied) if target < applied => Err(Error::Invalid(format!(
73 "target version V{target} earlier than latest applied version V{applied}",
74 )))?,
75 _ if target > source => Err(Error::Invalid(format!(
76 "target version V{target} does not exist, latest version found was V{source}",
77 )))?,
78 _ => Ok(()),
79 }
80 } else {
81 Ok(())
82 }
83 }
84
85 pub async fn run_apply(
87 &mut self,
88 target_version: Option<i64>,
89 dryrun: bool,
90 ) -> TernResult<Report> {
91 self.validate_source().await?;
92 let last_applied = self.context.latest_version().await?;
93 self.validate_target(last_applied, target_version)?;
94
95 let unapplied = self.context.migration_set(last_applied);
96
97 let mut results = Vec::new();
98 for migration in &unapplied.migrations {
99 let id = migration.migration_id();
100 let ver = migration.version();
101
102 if matches!(target_version, Some(end) if ver > end) {
104 break;
105 }
106
107 let result = if dryrun {
108 let query = migration
110 .build(&mut self.context)
111 .await
112 .with_report(&results)?;
113
114 MigrationResult::from_unapplied(migration.as_ref(), query.sql())
115 } else {
116 log::trace!("applying migration {id}");
117
118 self.context
119 .apply(migration.as_ref())
120 .await
121 .tern_migration_result(migration.as_ref())
122 .with_report(&results)
123 .map(|v| MigrationResult::from_applied(&v, Some(migration.no_tx())))?
124 };
125
126 results.push(result);
127 }
128
129 Ok(Report::new(results))
130 }
131
132 #[deprecated(since = "3.1.0", note = "use `run_apply_all`")]
134 pub async fn apply_all(&mut self) -> TernResult<Report> {
135 self.run_apply(None, false).await
136 }
137
138 pub async fn run_apply_all(&mut self, dryrun: bool) -> TernResult<Report> {
140 self.run_apply(None, dryrun).await
141 }
142
143 pub async fn list_applied(&mut self) -> TernResult<Report> {
145 self.validate_source().await?;
146
147 let applied = self
148 .context
149 .previously_applied()
150 .await?
151 .iter()
152 .map(|m| MigrationResult::from_applied(m, None))
153 .collect::<Vec<_>>();
154 let report = Report::new(applied);
155
156 Ok(report)
157 }
158
159 #[deprecated(since = "3.1.0", note = "no valid use case for `start_version`")]
160 pub async fn soft_apply(
161 &mut self,
162 start_version: Option<i64>,
163 target_version: Option<i64>,
164 ) -> TernResult<Report> {
165 if start_version.is_some() {
166 return Err(Error::Invalid(
167 "no valid `start_version` other than the first unapplied, use `run_soft_apply`"
168 .into(),
169 ));
170 }
171 self.run_soft_apply(target_version, false).await
172 }
173
174 pub async fn run_soft_apply(
182 &mut self,
183 target_version: Option<i64>,
184 dryrun: bool,
185 ) -> TernResult<Report> {
186 self.validate_source().await?;
187 let last_applied = self.context.latest_version().await?;
188 self.validate_target(last_applied, target_version)?;
189
190 let unapplied = self.context.migration_set(last_applied);
191
192 let mut results = Vec::new();
193 for migration in &unapplied.migrations {
194 let id = migration.migration_id();
195 let ver = migration.version();
196
197 if matches!(target_version, Some(end) if ver > end) {
199 break;
200 }
201
202 let query = migration
204 .build(&mut self.context)
205 .await
206 .with_report(&results)?;
207 let mut content = String::from("-- SOFT APPLIED:\n\n");
208 writeln!(content, "{query}")?;
209
210 let applied = migration.to_applied(0, Utc::now(), &content);
211 let result = MigrationResult::from_soft_applied(&applied, dryrun);
212
213 if !dryrun {
214 log::trace!("soft applying migration {id}");
215 self.context
216 .insert_applied(&applied)
217 .await
218 .with_report(&results)?;
219 }
220
221 results.push(result);
222 }
223 let report = Report::new(results);
224
225 Ok(report)
226 }
227}
228
229#[derive(Clone, Serialize, DebugAsJson, DisplayAsJsonPretty, Default)]
231pub struct Report {
232 migrations: Vec<MigrationResult>,
233}
234
235impl Report {
236 pub fn new(migrations: Vec<MigrationResult>) -> Self {
237 Self { migrations }
238 }
239
240 pub fn count(&self) -> usize {
241 self.migrations.len()
242 }
243
244 pub fn results(&self) -> Vec<MigrationResult> {
246 self.migrations.clone()
247 }
248
249 pub fn iter_results(&self) -> impl Iterator<Item = MigrationResult> {
251 self.migrations.clone().into_iter()
252 }
253}
254
255#[derive(Clone, Serialize, DebugAsJson, DisplayAsJsonPretty)]
258#[allow(dead_code)]
259pub struct MigrationResult {
260 dryrun: bool,
261 version: i64,
262 state: MigrationState,
263 applied_at: Option<DateTime<Utc>>,
264 description: String,
265 content: String,
266 transactional: Transactional,
267 duration_ms: RunDuration,
268}
269
270impl MigrationResult {
271 pub(crate) fn from_applied(applied: &AppliedMigration, no_tx: Option<bool>) -> Self {
272 Self {
273 dryrun: false,
274 version: applied.version,
275 state: MigrationState::Applied,
276 applied_at: Some(applied.applied_at),
277 description: applied.description.clone(),
278 content: applied.content.clone(),
279 transactional: no_tx
280 .map(Transactional::from_boolean)
281 .unwrap_or(Transactional::Other("Previously applied".to_string())),
282 duration_ms: RunDuration::Duration(applied.duration_ms),
283 }
284 }
285
286 pub(crate) fn from_soft_applied(applied: &AppliedMigration, dryrun: bool) -> Self {
287 Self {
288 dryrun,
289 version: applied.version,
290 state: MigrationState::SoftApplied,
291 applied_at: Some(applied.applied_at),
292 description: applied.description.clone(),
293 content: applied.content.clone(),
294 transactional: Transactional::Other("Soft applied".to_string()),
295 duration_ms: RunDuration::Duration(applied.duration_ms),
296 }
297 }
298
299 pub(crate) fn from_unapplied<M>(migration: &M, content: &str) -> Self
300 where
301 M: Migration + ?Sized,
302 {
303 Self {
304 dryrun: true,
305 version: migration.version(),
306 state: MigrationState::Unapplied,
307 applied_at: None,
308 description: migration.migration_id().description(),
309 content: content.into(),
310 transactional: Transactional::from_boolean(migration.no_tx()),
311 duration_ms: RunDuration::Unapplied,
312 }
313 }
314}
315
316#[derive(Debug, Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Serialize)]
317enum MigrationState {
318 Applied,
319 SoftApplied,
320 Unapplied,
321}
322
323#[derive(Debug, Clone, Serialize)]
324enum Transactional {
325 NoTransaction,
326 InTransaction,
327 Other(String),
328}
329
330impl Transactional {
331 fn from_boolean(v: bool) -> Self {
332 if v {
333 return Self::NoTransaction;
334 };
335 Self::InTransaction
336 }
337}
338
339#[derive(Debug, Clone, Copy, Serialize)]
340enum RunDuration {
341 Duration(i64),
342 Unapplied,
343}
344
345impl std::fmt::Display for Transactional {
346 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
347 match self {
348 Self::NoTransaction => write!(f, "No Transaction"),
349 Self::InTransaction => write!(f, "In Transaction"),
350 Self::Other(s) => write!(f, "{s}"),
351 }
352 }
353}
354
355impl std::fmt::Display for MigrationState {
356 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
357 match self {
358 Self::Applied => write!(f, "Applied"),
359 Self::SoftApplied => write!(f, "Soft Applied"),
360 Self::Unapplied => write!(f, "Not Applied"),
361 }
362 }
363}
364
365impl std::fmt::Display for RunDuration {
366 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
367 match self {
368 Self::Duration(ms) => write!(f, "{}ms", ms),
369 Self::Unapplied => write!(f, "Not Applied"),
370 }
371 }
372}
373
374fn check_migrations_in_sync(
376 applied: HashSet<MigrationId>,
377 source: HashSet<MigrationId>,
378) -> TernResult<()> {
379 let source_not_found: Vec<&MigrationId> = applied.difference(&source).collect();
380
381 if !source_not_found.is_empty() {
382 return Err(Error::OutOfSync {
383 at_issue: source_not_found.into_iter().cloned().collect(),
384 msg: "version/name applied but missing in source".into(),
385 });
386 }
387
388 Ok(())
389}
390
#[cfg(test)]
mod tests {
    use super::Error;
    use super::MigrationId;
    use super::TernResult;

    use std::collections::HashSet;

    /// Builds a set of migration ids from `(version, description)` pairs.
    fn ids(pairs: &[(i64, &str)]) -> HashSet<MigrationId> {
        pairs
            .iter()
            .map(|&(version, description)| MigrationId::new(version, description.into()))
            .collect()
    }

    /// Asserts that `result` is `Error::OutOfSync` naming exactly the migrations in `expected`.
    ///
    /// The migrations at issue originate from a `HashSet` difference, so they are compared as
    /// a set to keep these tests independent of iteration order.
    fn assert_out_of_sync(result: TernResult<()>, expected: HashSet<MigrationId>) {
        let Err(Error::OutOfSync { at_issue, .. }) = result else {
            panic!("expected Error::OutOfSync");
        };
        let got: HashSet<MigrationId> = at_issue.into_iter().collect();
        assert!(got == expected);
    }

    #[test]
    fn in_sync_when_source_matches_or_is_ahead() {
        let applied = ids(&[(1, "first"), (2, "second")]);
        let same = ids(&[(1, "first"), (2, "second")]);
        assert!(super::check_migrations_in_sync(applied.clone(), same).is_ok());

        let ahead = ids(&[(1, "first"), (2, "second"), (3, "third")]);
        assert!(super::check_migrations_in_sync(applied, ahead).is_ok());
    }

    #[test]
    fn missing_source() {
        let source = ids(&[(1, "first"), (2, "second"), (3, "fourth")]);
        let applied = ids(&[(1, "first"), (2, "second"), (3, "third")]);
        let result = super::check_migrations_in_sync(applied, source);
        assert_out_of_sync(result, ids(&[(3, "third")]));
    }

    #[test]
    fn fewer_in_source() {
        let source = ids(&[(1, "first"), (2, "second"), (3, "third")]);
        let applied = ids(&[(1, "first"), (2, "second"), (3, "third"), (4, "fourth")]);
        let result = super::check_migrations_in_sync(applied, source);
        assert_out_of_sync(result, ids(&[(4, "fourth")]));
    }

    #[test]
    fn mismatched_source() {
        let source = ids(&[
            (1, "first"),
            (2, "second"),
            (3, "third"),
            (4, "fifth"),
            (5, "sixth"),
            (6, "seventh"),
            (7, "eighth"),
        ]);
        let applied = ids(&[
            (1, "first"),
            (2, "second"),
            (3, "third"),
            (4, "fourth"),
            (5, "fifth"),
        ]);
        let result = super::check_migrations_in_sync(applied, source);
        assert_out_of_sync(result, ids(&[(4, "fourth"), (5, "fifth")]));
    }
}