1#![warn(clippy::all)]
10#![forbid(unsafe_code)]
11
12use std::collections::{HashMap, HashSet};
13use std::fmt::{Debug, Display};
14use std::hash::Hash;
15use std::rc::Rc;
16use std::sync::Arc;
17
18use daggy::petgraph::EdgeDirection;
19use daggy::{Dag, Walker};
20use indexmap::IndexSet;
21use log::{debug, info};
22use thiserror::Error;
23
24use crate::traversal::DfsPostOrderDirectional;
25
26#[macro_use]
27pub mod testing;
28mod traversal;
29
/// Metadata describing a single migration.
///
/// `I` is the type used to identify migrations. Identifiers must be unique
/// among the migrations registered with one `Migrator`; registering a second
/// migration with an already-known identifier is rejected.
pub trait Migration<I> {
    /// Returns the identifier of this migration.
    fn id(&self) -> I;

    /// Returns the identifiers of the migrations this migration depends on.
    fn dependencies(&self) -> HashSet<I>;

    /// Returns a static, human-readable description of this migration.
    fn description(&self) -> &'static str;
}
43
44impl<I, T> Migration<I> for Box<T>
45where
46 T: Migration<I> + ?Sized,
47{
48 fn id(&self) -> I {
49 self.as_ref().id()
50 }
51
52 fn dependencies(&self) -> HashSet<I> {
53 self.as_ref().dependencies()
54 }
55
56 fn description(&self) -> &'static str {
57 self.as_ref().description()
58 }
59}
60
61impl<I, T> Migration<I> for Rc<T>
62where
63 T: Migration<I> + ?Sized,
64{
65 fn id(&self) -> I {
66 self.as_ref().id()
67 }
68
69 fn dependencies(&self) -> HashSet<I> {
70 self.as_ref().dependencies()
71 }
72
73 fn description(&self) -> &'static str {
74 self.as_ref().description()
75 }
76}
77
78impl<I, T> Migration<I> for Arc<T>
79where
80 T: Migration<I> + ?Sized,
81{
82 fn id(&self) -> I {
83 self.as_ref().id()
84 }
85
86 fn dependencies(&self) -> HashSet<I> {
87 self.as_ref().dependencies()
88 }
89
90 fn description(&self) -> &'static str {
91 self.as_ref().description()
92 }
93}
94
/// Generates a `Migration` implementation for an existing type.
///
/// Two forms are accepted:
///
/// * `migration!(Name, id, [dep, ...], description)` uses `::uuid::Uuid` as
///   the identifier type.
/// * `migration!(IdType, Name, id, [dep, ...], description)` uses `IdType` as
///   the identifier type.
///
/// `id` and each `dep` are expressions evaluated on every call of the
/// generated `id()` / `dependencies()` methods; `description` must be a
/// `&'static str` expression. The dependency list may be empty and may end
/// with a trailing comma.
#[macro_export]
macro_rules! migration {
    ($name:ident, $id:expr, [ $( $dependency_id:expr ),* $(,)? ], $description:expr) => {
        // Recurse through `$crate::` so the expansion also resolves when the
        // caller invoked the macro by path without importing it.
        $crate::migration!(::uuid::Uuid, $name, $id, [$($dependency_id),*], $description);
    };
    ($ty:path, $name:ident, $id:expr, [ $( $dependency_id:expr ),* $(,)? ], $description:expr) => {
        impl $crate::Migration<$ty> for $name {
            fn id(&self) -> $ty {
                $id
            }

            fn dependencies(&self) -> ::std::collections::HashSet<$ty> {
                ::std::collections::HashSet::from([
                    $(
                        $dependency_id,
                    )*
                ])
            }

            fn description(&self) -> &'static str {
                $description
            }
        }
    };
}
154
/// The direction in which a migration is being run.
#[derive(Debug)]
pub enum MigrationDirection {
    /// The migration is being applied.
    Up,
    /// The migration is being reverted.
    Down,
}
161
162impl Display for MigrationDirection {
163 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164 let printable = match *self {
165 MigrationDirection::Up => "up",
166 MigrationDirection::Down => "Down",
167 };
168 write!(f, "{}", printable)
169 }
170}
171
172pub trait Adapter<I> {
175 type MigrationType: Migration<I>;
177
178 type Error: std::error::Error + 'static;
180
181 fn applied_migrations(&mut self) -> Result<HashSet<I>, Self::Error>;
183
184 fn apply_migration(&mut self, _: &Self::MigrationType) -> Result<(), Self::Error>;
186
187 fn revert_migration(&mut self, _: &Self::MigrationType) -> Result<(), Self::Error>;
189}
190
191#[derive(Debug, Error)]
193pub enum DependencyError<I> {
194 #[error("Duplicate migration ID {0}")]
195 DuplicateId(I),
196 #[error("Unknown migration ID {0}")]
197 UnknownId(I),
198 #[error("Cyclic dependency caused by edge from migration IDs {from} to {to}")]
199 Cycle { from: I, to: I },
200}
201
202#[derive(Debug, Error)]
205pub enum MigratorError<I, T: std::error::Error + 'static> {
206 #[error("An error occurred due to migration dependencies")]
207 Dependency(#[source] DependencyError<I>),
208 #[error("An error occurred while interacting with the adapter.")]
209 Adapter(#[from] T),
210 #[error(
211 "An error occurred while applying migration {id} ({description}) {direction}: {error}."
212 )]
213 Migration {
214 id: I,
215 description: &'static str,
216 direction: MigrationDirection,
217 #[source]
218 error: T,
219 },
220}
221
222pub struct Migrator<I, T: Adapter<I>> {
224 adapter: T,
225 dependencies: Dag<T::MigrationType, ()>,
226 id_map: HashMap<I, daggy::NodeIndex>,
227}
228
229impl<I, T> Migrator<I, T>
230where
231 I: Hash + Display + Eq + Clone,
232 T: Adapter<I>,
233{
234 pub fn new(adapter: T) -> Migrator<I, T> {
236 Migrator {
237 adapter,
238 dependencies: Dag::new(),
239 id_map: HashMap::new(),
240 }
241 }
242
243 pub fn register(
245 &mut self,
246 migration: T::MigrationType,
247 ) -> Result<(), MigratorError<I, T::Error>> {
248 let id = migration.id();
249 debug!("Registering migration {}", id);
250 if self.id_map.contains_key(&id) {
251 return Err(MigratorError::Dependency(DependencyError::DuplicateId(id)));
252 }
253
254 let migration_idx = self.dependencies.add_node(migration);
255 self.id_map.insert(id, migration_idx);
256
257 Ok(())
258 }
259
260 pub fn register_multiple(
262 &mut self,
263 migrations: impl Iterator<Item = T::MigrationType>,
264 ) -> Result<(), MigratorError<I, T::Error>> {
265 for migration in migrations {
266 let id = migration.id();
267 debug!("Registering migration (with multiple) {}", id);
268 if self.id_map.contains_key(&id) {
269 return Err(MigratorError::Dependency(DependencyError::DuplicateId(id)));
270 }
271
272 let migration_idx = self.dependencies.add_node(migration);
273 self.id_map.insert(id, migration_idx);
274 }
275
276 Ok(())
277 }
278
279 fn register_edges(&mut self) -> Result<(), MigratorError<I, T::Error>> {
281 for (id, migration_idx) in self.id_map.iter() {
282 let depends = self
283 .dependencies
284 .node_weight(*migration_idx)
285 .expect("We registered these nodes")
286 .dependencies();
287
288 for d in depends {
289 let parent_idx = self.id_map.get(&d).ok_or_else(|| {
290 MigratorError::Dependency(DependencyError::UnknownId(d.clone()))
291 })?;
292 self.dependencies
293 .add_edge(*parent_idx, *migration_idx, ())
294 .map_err(|_| {
295 MigratorError::Dependency(DependencyError::Cycle {
296 from: d,
297 to: id.clone(),
298 })
299 })?;
300 }
301 }
302 Ok(())
303 }
304
305 fn induced_stream(
311 &self,
312 id: Option<I>,
313 dir: EdgeDirection,
314 ) -> Result<IndexSet<daggy::NodeIndex>, DependencyError<I>> {
315 let mut to_visit = Vec::new();
316 match id {
317 Some(id) => {
318 if let Some(id) = self.id_map.get(&id) {
319 to_visit.push(*id);
320 } else {
321 return Err(DependencyError::UnknownId(id));
322 }
323 }
324 None => to_visit.extend(self.dependencies.graph().externals(dir.opposite())),
325 }
326
327 let mut target_set = IndexSet::new();
328
329 for idx in to_visit {
330 if !target_set.contains(&idx) {
331 let walker = DfsPostOrderDirectional::new(dir, &self.dependencies, idx);
332 let nodes: Vec<daggy::NodeIndex> = walker.iter(&self.dependencies).collect();
333 target_set.extend(nodes.iter());
334 }
335 }
336
337 Ok(target_set)
338 }
339
340 pub fn up(&mut self, to: Option<I>) -> Result<(), MigratorError<I, T::Error>> {
345 if let Some(to) = &to {
346 info!("Migrating up to target: {}", to);
347 } else {
348 info!("Migrating everything");
349 }
350
351 self.register_edges()?;
353
354 let target_idxs = self
355 .induced_stream(to, EdgeDirection::Incoming)
356 .map_err(MigratorError::Dependency)?;
357
358 let applied_migrations = self.adapter.applied_migrations()?;
361 for idx in target_idxs {
362 let migration = &self.dependencies[idx];
363 let id = migration.id();
364 if applied_migrations.contains(&id) {
365 continue;
366 }
367
368 info!("Applying migration {}", id);
369 self.adapter
370 .apply_migration(migration)
371 .map_err(|e| MigratorError::Migration {
372 id,
373 description: migration.description(),
374 direction: MigrationDirection::Up,
375 error: e,
376 })?;
377 }
378
379 Ok(())
380 }
381
382 pub fn down(&mut self, to: Option<I>) -> Result<(), MigratorError<I, T::Error>> {
388 if let Some(to) = &to {
389 info!("Migrating up to target: {}", to);
390 } else {
391 info!("Migrating everything");
392 }
393
394 self.register_edges()?;
396
397 let mut target_idxs = self
398 .induced_stream(to.clone(), EdgeDirection::Outgoing)
399 .map_err(MigratorError::Dependency)?;
400 if let Some(sink_id) = to {
401 target_idxs.remove(
402 self.id_map
403 .get(&sink_id)
404 .expect("Id is checked in induced_stream and exists"),
405 );
406 }
407
408 let applied_migrations = self.adapter.applied_migrations()?;
409 for idx in target_idxs {
410 let migration = &self.dependencies[idx];
411 let id = migration.id();
412 if !applied_migrations.contains(&id) {
413 continue;
414 }
415
416 info!("Reverting migration {}", id);
417 self.adapter
418 .revert_migration(migration)
419 .map_err(|e| MigratorError::Migration {
420 id,
421 description: migration.description(),
422 direction: MigrationDirection::Down,
423 error: e,
424 })?;
425 }
426
427 Ok(())
428 }
429}
430
#[cfg(test)]
pub mod tests {
    use std::cell::RefCell;

    use super::testing::*;
    use super::*;

    /// Adapter that only remembers the applied migration ids in memory.
    struct DefaultTestAdapter {
        applied_migrations: HashSet<usize>,
    }

    impl DefaultTestAdapter {
        /// Creates an adapter with nothing applied.
        fn new() -> Self {
            Self {
                applied_migrations: HashSet::new(),
            }
        }
    }

    /// Error type required by [`Adapter`]; the adapters in this module never
    /// actually return it.
    #[derive(Debug, Error)]
    #[error("An error occurred.")]
    struct DefaultTestAdapterError;

    impl Adapter<usize> for DefaultTestAdapter {
        type MigrationType = TestMigration<usize>;

        type Error = DefaultTestAdapterError;

        fn applied_migrations(&mut self) -> Result<HashSet<usize>, Self::Error> {
            let snapshot = self.applied_migrations.clone();
            Ok(snapshot)
        }

        fn apply_migration(&mut self, migration: &Self::MigrationType) -> Result<(), Self::Error> {
            self.applied_migrations.insert(migration.id());
            Ok(())
        }

        fn revert_migration(&mut self, migration: &Self::MigrationType) -> Result<(), Self::Error> {
            self.applied_migrations.remove(&migration.id());
            Ok(())
        }
    }

    impl TestAdapter<usize> for DefaultTestAdapter {
        fn mock(id: usize, dependencies: HashSet<usize>) -> Self::MigrationType {
            TestMigration::new(id, dependencies)
        }
    }

    test_schemerz_adapter!(DefaultTestAdapter::new(), 0..);

    /// Migration that runs a caller-supplied callback whenever it is applied
    /// or reverted, so tests can observe the order of operations.
    pub struct TestMigrationWithCheck {
        id: usize,
        dependencies: HashSet<usize>,
        check_fn_up: Box<dyn Fn()>,
        check_fn_down: Box<dyn Fn()>,
    }

    impl TestMigrationWithCheck {
        /// Builds a migration whose `check_fn_up` / `check_fn_down` callbacks
        /// are invoked by [`TestAdapterWithCheck`] on apply / revert.
        pub fn new<Fup, Fdown>(
            id: usize,
            dependencies: HashSet<usize>,
            check_fn_up: Fup,
            check_fn_down: Fdown,
        ) -> Self
        where
            Fup: Fn() + 'static,
            Fdown: Fn() + 'static,
        {
            let check_fn_up: Box<dyn Fn()> = Box::new(check_fn_up);
            let check_fn_down: Box<dyn Fn()> = Box::new(check_fn_down);
            TestMigrationWithCheck {
                id,
                dependencies,
                check_fn_up,
                check_fn_down,
            }
        }
    }

    impl Migration<usize> for TestMigrationWithCheck {
        fn id(&self) -> usize {
            self.id
        }

        fn dependencies(&self) -> HashSet<usize> {
            self.dependencies.clone()
        }

        fn description(&self) -> &'static str {
            "Test Migration"
        }
    }

    /// Adapter that records applied ids and then fires the migration's
    /// callback for the corresponding direction.
    #[derive(Default)]
    struct TestAdapterWithCheck {
        applied_migrations: HashSet<usize>,
    }

    impl Adapter<usize> for TestAdapterWithCheck {
        type MigrationType = TestMigrationWithCheck;

        type Error = DefaultTestAdapterError;

        fn applied_migrations(&mut self) -> Result<HashSet<usize>, Self::Error> {
            let snapshot = self.applied_migrations.clone();
            Ok(snapshot)
        }

        fn apply_migration(&mut self, migration: &Self::MigrationType) -> Result<(), Self::Error> {
            self.applied_migrations.insert(migration.id());
            let on_up = &migration.check_fn_up;
            on_up();
            Ok(())
        }

        fn revert_migration(&mut self, migration: &Self::MigrationType) -> Result<(), Self::Error> {
            self.applied_migrations.remove(&migration.id());
            let on_down = &migration.check_fn_down;
            on_down();
            Ok(())
        }
    }

    /// Registers a binary tree of migrations (ids 1 through 15, where
    /// migration `n` depends on `n / 2`) and checks that dependencies are
    /// always applied before, and reverted after, their dependants.
    #[test]
    fn test_migrations_run_order() {
        let ran_migrations: Rc<RefCell<HashSet<usize>>> = Rc::new(RefCell::new(HashSet::new()));

        let mut migrator = Migrator::new(TestAdapterWithCheck::default());

        // Root of the tree: no dependencies and no ordering checks.
        let root_up = Rc::clone(&ran_migrations);
        let root_down = Rc::clone(&ran_migrations);
        migrator
            .register(TestMigrationWithCheck::new(
                1,
                HashSet::new(),
                move || {
                    root_up.borrow_mut().insert(1);
                },
                move || {
                    root_down.borrow_mut().remove(&1);
                },
            ))
            .unwrap();

        // Levels 1 through 3 of the tree: ids 2..=3, 4..=7 and 8..=15.
        for level in 1_u32..4 {
            let first_id = 2_usize.pow(level);
            for id in first_id..first_id * 2 {
                let dep = id / 2;
                let on_up = Rc::clone(&ran_migrations);
                let on_down = Rc::clone(&ran_migrations);
                migrator
                    .register(TestMigrationWithCheck::new(
                        id,
                        HashSet::from([dep]),
                        move || {
                            let mut ran = on_up.borrow_mut();
                            assert!(
                                ran.contains(&dep),
                                "Up called before dependency id:{}, dep:{}",
                                id,
                                dep
                            );
                            ran.insert(id);
                        },
                        move || {
                            let mut ran = on_down.borrow_mut();
                            for dependant in [id * 2, id * 2 + 1] {
                                assert!(
                                    !ran.contains(&dependant),
                                    "Down called before dependant id:{}, dep:{}",
                                    id,
                                    dependant
                                );
                            }
                            ran.remove(&id);
                        },
                    ))
                    .unwrap();
            }
        }

        // Copy of the ids currently recorded as run.
        let ran = || ran_migrations.borrow().clone();
        // Builds the expected set from a list of ids.
        let ids = |expected: &[usize]| expected.iter().copied().collect::<HashSet<usize>>();
        let everything = (1..16).collect::<HashSet<_>>();

        migrator.up(None).unwrap();
        assert_eq!(ran(), everything);

        migrator.down(None).unwrap();
        assert_eq!(ran(), ids(&[]));

        migrator.up(None).unwrap();

        // Reverts the subtree below 2, leaving 2 itself applied.
        migrator.down(Some(2)).unwrap();
        assert_eq!(ran(), ids(&[1, 2, 3, 6, 7, 12, 13, 14, 15]));

        // Reverts the subtree below 3, leaving 3 itself applied.
        migrator.down(Some(3)).unwrap();
        assert_eq!(ran(), ids(&[1, 2, 3]));

        migrator.down(None).unwrap();

        // Applies 3 and its only ancestor, 1.
        migrator.up(Some(3)).unwrap();
        assert_eq!(ran(), ids(&[1, 3]));

        // Applies 9 and its ancestors 4 and 2.
        migrator.up(Some(9)).unwrap();
        assert_eq!(ran(), ids(&[1, 2, 3, 4, 9]));

        // Applies 12 and its ancestor 6; 3 and 1 are already applied.
        migrator.up(Some(12)).unwrap();
        assert_eq!(ran(), ids(&[1, 2, 3, 4, 6, 9, 12]));

        migrator.up(None).unwrap();
        assert_eq!(ran(), everything);
    }
}