schemerz/
lib.rs

1//! A database schema migration library that supports directed acyclic graph
2//! (DAG) dependencies between migrations.
3//!
4//! To use with a specific database, an adapter is required. Known adapter
5//! crates:
6//!
7//! - PostgreSQL: [`schemerz-postgres`](https://crates.io/crates/schemerz-postgres)
8//! - SQLite: [`schemerz-rusqlite`](https://crates.io/crates/schemerz-rusqlite)
9#![warn(clippy::all)]
10#![forbid(unsafe_code)]
11
12use std::collections::{HashMap, HashSet};
13use std::fmt::{Debug, Display};
14use std::hash::Hash;
15use std::rc::Rc;
16use std::sync::Arc;
17
18use daggy::petgraph::EdgeDirection;
19use daggy::{Dag, Walker};
20use indexmap::IndexSet;
21use log::{debug, info};
22use thiserror::Error;
23
24use crate::traversal::DfsPostOrderDirectional;
25
26#[macro_use]
27pub mod testing;
28mod traversal;
29
/// Identity and dependency metadata shared by every migration.
///
/// This trait only describes *what* a migration is and *which* migrations it
/// directly depends on. Adapters layer their own requirements on top (via
/// `Adapter::MigrationType`) to actually apply and revert it.
pub trait Migration<I> {
    /// Returns the identifier that uniquely names this migration.
    fn id(&self) -> I;

    /// Returns the IDs of the migrations this one directly depends on.
    ///
    /// Only direct dependencies belong here; transitive dependencies are
    /// resolved through the migrator's dependency graph.
    fn dependencies(&self) -> HashSet<I>;

    /// Returns a human-readable description of this migration.
    fn description(&self) -> &'static str;
}
43
44impl<I, T> Migration<I> for Box<T>
45where
46    T: Migration<I> + ?Sized,
47{
48    fn id(&self) -> I {
49        self.as_ref().id()
50    }
51
52    fn dependencies(&self) -> HashSet<I> {
53        self.as_ref().dependencies()
54    }
55
56    fn description(&self) -> &'static str {
57        self.as_ref().description()
58    }
59}
60
61impl<I, T> Migration<I> for Rc<T>
62where
63    T: Migration<I> + ?Sized,
64{
65    fn id(&self) -> I {
66        self.as_ref().id()
67    }
68
69    fn dependencies(&self) -> HashSet<I> {
70        self.as_ref().dependencies()
71    }
72
73    fn description(&self) -> &'static str {
74        self.as_ref().description()
75    }
76}
77
78impl<I, T> Migration<I> for Arc<T>
79where
80    T: Migration<I> + ?Sized,
81{
82    fn id(&self) -> I {
83        self.as_ref().id()
84    }
85
86    fn dependencies(&self) -> HashSet<I> {
87        self.as_ref().dependencies()
88    }
89
90    fn description(&self) -> &'static str {
91        self.as_ref().description()
92    }
93}
94
/// Create a trivial implementation of `Migration` for a type.
///
/// Two forms are accepted:
///
/// - `migration!(Name, id, [dependency ids...], description)` implements
///   `Migration<uuid::Uuid>`; the invoking crate must depend on `uuid`.
/// - `migration!(IdType, Name, id, [dependency ids...], description)`
///   implements `Migration<IdType>` for an arbitrary ID type.
///
/// A trailing comma is accepted inside the dependency list and after the
/// description.
///
/// ## Example
///
/// ```rust
/// #[macro_use]
/// extern crate schemerz;
/// extern crate uuid;
///
/// use schemerz::Migration;
/// use uuid::uuid;
///
/// struct ParentMigration;
/// migration!(
///     ParentMigration,
///     uuid!("bc960dc8-0e4a-4182-a62a-8e776d1e2b30"),
///     [],
///     "Parent migration in a DAG");
///
/// struct ChildMigration;
/// migration!(
///     ChildMigration,
///     uuid!("4885e8ab-dafa-4d76-a565-2dee8b04ef60"),
///     [uuid!("bc960dc8-0e4a-4182-a62a-8e776d1e2b30")],
///     "Child migration in a DAG");
///
/// fn main() {
///     let parent = ParentMigration;
///     let child = ChildMigration;
///
///     assert!(child.dependencies().contains(&parent.id()));
/// }
/// ```
#[macro_export]
macro_rules! migration {
    ($name:ident, $id:expr, [ $( $dependency_id:expr ),* $(,)? ], $description:expr $(,)?) => {
        // `$crate::` makes the recursive call resolve even when the caller
        // invoked this macro by path (e.g. `schemerz::migration!`) without
        // importing it into scope.
        $crate::migration!(::uuid::Uuid, $name, $id, [$($dependency_id),*], $description);
    };
    ($ty:path, $name:ident, $id:expr, [ $( $dependency_id:expr ),* $(,)? ], $description:expr $(,)?) => {
        impl $crate::Migration<$ty> for $name
        {
            fn id(&self) -> $ty {
                $id
            }

            fn dependencies(&self) -> ::std::collections::HashSet<$ty> {
                ::std::collections::HashSet::from([
                    $(
                        $dependency_id,
                    )*
                ])
            }

            fn description(&self) -> &'static str {
                $description
            }
        }
    };
}
154
/// Direction in which a migration is applied (`Up`) or reverted (`Down`).
///
/// Reported in `MigratorError::Migration` to say which operation failed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MigrationDirection {
    /// The migration was being applied.
    Up,
    /// The migration was being reverted.
    Down,
}
161
162impl Display for MigrationDirection {
163    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164        let printable = match *self {
165            MigrationDirection::Up => "up",
166            MigrationDirection::Down => "Down",
167        };
168        write!(f, "{}", printable)
169    }
170}
171
172/// Trait necessary to adapt schemerz's migration management to a stateful
173/// backend.
174pub trait Adapter<I> {
175    /// Type migrations must implement for this adapter.
176    type MigrationType: Migration<I>;
177
178    /// Type of errors returned by this adapter.
179    type Error: std::error::Error + 'static;
180
181    /// Returns the set of IDs for migrations that have been applied.
182    fn applied_migrations(&mut self) -> Result<HashSet<I>, Self::Error>;
183
184    /// Apply a single migration.
185    fn apply_migration(&mut self, _: &Self::MigrationType) -> Result<(), Self::Error>;
186
187    /// Revert a single migration.
188    fn revert_migration(&mut self, _: &Self::MigrationType) -> Result<(), Self::Error>;
189}
190
191/// Error resulting from the definition of migration identity and dependency.
192#[derive(Debug, Error)]
193pub enum DependencyError<I> {
194    #[error("Duplicate migration ID {0}")]
195    DuplicateId(I),
196    #[error("Unknown migration ID {0}")]
197    UnknownId(I),
198    #[error("Cyclic dependency caused by edge from migration IDs {from} to {to}")]
199    Cycle { from: I, to: I },
200}
201
202/// Error resulting either from migration definitions or from migration
203/// application with an adapter.
204#[derive(Debug, Error)]
205pub enum MigratorError<I, T: std::error::Error + 'static> {
206    #[error("An error occurred due to migration dependencies")]
207    Dependency(#[source] DependencyError<I>),
208    #[error("An error occurred while interacting with the adapter.")]
209    Adapter(#[from] T),
210    #[error(
211        "An error occurred while applying migration {id} ({description}) {direction}: {error}."
212    )]
213    Migration {
214        id: I,
215        description: &'static str,
216        direction: MigrationDirection,
217        #[source]
218        error: T,
219    },
220}
221
222/// Primary schemerz type for defining and applying migrations.
223pub struct Migrator<I, T: Adapter<I>> {
224    adapter: T,
225    dependencies: Dag<T::MigrationType, ()>,
226    id_map: HashMap<I, daggy::NodeIndex>,
227}
228
229impl<I, T> Migrator<I, T>
230where
231    I: Hash + Display + Eq + Clone,
232    T: Adapter<I>,
233{
234    /// Create a `Migrator` using the given `Adapter`.
235    pub fn new(adapter: T) -> Migrator<I, T> {
236        Migrator {
237            adapter,
238            dependencies: Dag::new(),
239            id_map: HashMap::new(),
240        }
241    }
242
243    /// Register a migration into the dependency graph.
244    pub fn register(
245        &mut self,
246        migration: T::MigrationType,
247    ) -> Result<(), MigratorError<I, T::Error>> {
248        let id = migration.id();
249        debug!("Registering migration {}", id);
250        if self.id_map.contains_key(&id) {
251            return Err(MigratorError::Dependency(DependencyError::DuplicateId(id)));
252        }
253
254        let migration_idx = self.dependencies.add_node(migration);
255        self.id_map.insert(id, migration_idx);
256
257        Ok(())
258    }
259
260    /// Register multiple migrations into the dependency graph.
261    pub fn register_multiple(
262        &mut self,
263        migrations: impl Iterator<Item = T::MigrationType>,
264    ) -> Result<(), MigratorError<I, T::Error>> {
265        for migration in migrations {
266            let id = migration.id();
267            debug!("Registering migration (with multiple) {}", id);
268            if self.id_map.contains_key(&id) {
269                return Err(MigratorError::Dependency(DependencyError::DuplicateId(id)));
270            }
271
272            let migration_idx = self.dependencies.add_node(migration);
273            self.id_map.insert(id, migration_idx);
274        }
275
276        Ok(())
277    }
278
279    /// Creates the edges for the current migrations into the dependency graph.
280    fn register_edges(&mut self) -> Result<(), MigratorError<I, T::Error>> {
281        for (id, migration_idx) in self.id_map.iter() {
282            let depends = self
283                .dependencies
284                .node_weight(*migration_idx)
285                .expect("We registered these nodes")
286                .dependencies();
287
288            for d in depends {
289                let parent_idx = self.id_map.get(&d).ok_or_else(|| {
290                    MigratorError::Dependency(DependencyError::UnknownId(d.clone()))
291                })?;
292                self.dependencies
293                    .add_edge(*parent_idx, *migration_idx, ())
294                    .map_err(|_| {
295                        MigratorError::Dependency(DependencyError::Cycle {
296                            from: d,
297                            to: id.clone(),
298                        })
299                    })?;
300            }
301        }
302        Ok(())
303    }
304
305    /// Collect the ids of recursively dependent migrations in `dir` induced
306    /// starting from `id`. If `dir` is `Incoming`, this is all ancestors
307    /// (dependencies); if `Outgoing`, this is all descendents (dependents).
308    /// If `id` is `None`, this is all migrations starting from the sources or
309    /// the sinks, respectively.
310    fn induced_stream(
311        &self,
312        id: Option<I>,
313        dir: EdgeDirection,
314    ) -> Result<IndexSet<daggy::NodeIndex>, DependencyError<I>> {
315        let mut to_visit = Vec::new();
316        match id {
317            Some(id) => {
318                if let Some(id) = self.id_map.get(&id) {
319                    to_visit.push(*id);
320                } else {
321                    return Err(DependencyError::UnknownId(id));
322                }
323            }
324            None => to_visit.extend(self.dependencies.graph().externals(dir.opposite())),
325        }
326
327        let mut target_set = IndexSet::new();
328
329        for idx in to_visit {
330            if !target_set.contains(&idx) {
331                let walker = DfsPostOrderDirectional::new(dir, &self.dependencies, idx);
332                let nodes: Vec<daggy::NodeIndex> = walker.iter(&self.dependencies).collect();
333                target_set.extend(nodes.iter());
334            }
335        }
336
337        Ok(target_set)
338    }
339
340    /// Apply migrations as necessary to so that the specified migration is
341    /// applied (inclusive).
342    ///
343    /// If `to` is `None`, apply all registered migrations.
344    pub fn up(&mut self, to: Option<I>) -> Result<(), MigratorError<I, T::Error>> {
345        if let Some(to) = &to {
346            info!("Migrating up to target: {}", to);
347        } else {
348            info!("Migrating everything");
349        }
350
351        // Register the edges
352        self.register_edges()?;
353
354        let target_idxs = self
355            .induced_stream(to, EdgeDirection::Incoming)
356            .map_err(MigratorError::Dependency)?;
357
358        // TODO: This is assuming the applied_migrations state is consistent
359        // with the dependency graph.
360        let applied_migrations = self.adapter.applied_migrations()?;
361        for idx in target_idxs {
362            let migration = &self.dependencies[idx];
363            let id = migration.id();
364            if applied_migrations.contains(&id) {
365                continue;
366            }
367
368            info!("Applying migration {}", id);
369            self.adapter
370                .apply_migration(migration)
371                .map_err(|e| MigratorError::Migration {
372                    id,
373                    description: migration.description(),
374                    direction: MigrationDirection::Up,
375                    error: e,
376                })?;
377        }
378
379        Ok(())
380    }
381
382    /// Revert migrations as necessary so that no migrations dependent on the
383    /// specified migration are applied. If the specified migration was already
384    /// applied, it will still be applied.
385    ///
386    /// If `to` is `None`, revert all applied migrations.
387    pub fn down(&mut self, to: Option<I>) -> Result<(), MigratorError<I, T::Error>> {
388        if let Some(to) = &to {
389            info!("Migrating up to target: {}", to);
390        } else {
391            info!("Migrating everything");
392        }
393
394        // Register the edges
395        self.register_edges()?;
396
397        let mut target_idxs = self
398            .induced_stream(to.clone(), EdgeDirection::Outgoing)
399            .map_err(MigratorError::Dependency)?;
400        if let Some(sink_id) = to {
401            target_idxs.remove(
402                self.id_map
403                    .get(&sink_id)
404                    .expect("Id is checked in induced_stream and exists"),
405            );
406        }
407
408        let applied_migrations = self.adapter.applied_migrations()?;
409        for idx in target_idxs {
410            let migration = &self.dependencies[idx];
411            let id = migration.id();
412            if !applied_migrations.contains(&id) {
413                continue;
414            }
415
416            info!("Reverting migration {}", id);
417            self.adapter
418                .revert_migration(migration)
419                .map_err(|e| MigratorError::Migration {
420                    id,
421                    description: migration.description(),
422                    direction: MigrationDirection::Down,
423                    error: e,
424                })?;
425        }
426
427        Ok(())
428    }
429}
430
#[cfg(test)]
pub mod tests {
    use std::cell::RefCell;

    use super::testing::*;
    use super::*;

    /// In-memory adapter that only records which migration IDs are applied.
    struct DefaultTestAdapter {
        applied_migrations: HashSet<usize>,
    }

    impl DefaultTestAdapter {
        fn new() -> DefaultTestAdapter {
            Self {
                applied_migrations: HashSet::new(),
            }
        }
    }

    /// Error type shared by the test adapters; none of the adapters in this
    /// module ever return it.
    #[derive(Debug, Error)]
    #[error("An error occurred.")]
    struct DefaultTestAdapterError;

    impl Adapter<usize> for DefaultTestAdapter {
        type MigrationType = TestMigration<usize>;

        type Error = DefaultTestAdapterError;

        fn applied_migrations(&mut self) -> Result<HashSet<usize>, Self::Error> {
            Ok(self.applied_migrations.clone())
        }

        fn apply_migration(&mut self, migration: &Self::MigrationType) -> Result<(), Self::Error> {
            self.applied_migrations.insert(migration.id());
            Ok(())
        }

        fn revert_migration(&mut self, migration: &Self::MigrationType) -> Result<(), Self::Error> {
            self.applied_migrations.remove(&migration.id());
            Ok(())
        }
    }

    impl TestAdapter<usize> for DefaultTestAdapter {
        fn mock(id: usize, dependencies: HashSet<usize>) -> Self::MigrationType {
            TestMigration::new(id, dependencies)
        }
    }

    test_schemerz_adapter!(DefaultTestAdapter::new(), 0..);

    /// Migration that runs a caller-supplied callback when applied or
    /// reverted, letting tests observe the order in which migrations run.
    pub struct TestMigrationWithCheck {
        id: usize,
        dependencies: HashSet<usize>,
        check_fn_up: Box<dyn Fn()>,
        check_fn_down: Box<dyn Fn()>,
    }

    impl TestMigrationWithCheck {
        pub fn new<Fup, Fdown>(
            id: usize,
            dependencies: HashSet<usize>,
            check_fn_up: Fup,
            check_fn_down: Fdown,
        ) -> Self
        where
            Fup: Fn() + 'static,
            Fdown: Fn() + 'static,
        {
            Self {
                id,
                dependencies,
                check_fn_up: Box::new(check_fn_up),
                check_fn_down: Box::new(check_fn_down),
            }
        }
    }

    impl Migration<usize> for TestMigrationWithCheck {
        fn id(&self) -> usize {
            self.id
        }

        fn dependencies(&self) -> HashSet<usize> {
            self.dependencies.clone()
        }

        fn description(&self) -> &'static str {
            "Test Migration"
        }
    }

    /// Adapter that records applied IDs and then invokes the migration's
    /// check callback for the corresponding direction.
    #[derive(Default)]
    struct TestAdapterWithCheck {
        applied_migrations: HashSet<usize>,
    }

    impl Adapter<usize> for TestAdapterWithCheck {
        type MigrationType = TestMigrationWithCheck;

        type Error = DefaultTestAdapterError;

        fn applied_migrations(&mut self) -> Result<HashSet<usize>, Self::Error> {
            Ok(self.applied_migrations.clone())
        }

        fn apply_migration(&mut self, migration: &Self::MigrationType) -> Result<(), Self::Error> {
            self.applied_migrations.insert(migration.id());
            (migration.check_fn_up)();
            Ok(())
        }

        fn revert_migration(&mut self, migration: &Self::MigrationType) -> Result<(), Self::Error> {
            self.applied_migrations.remove(&migration.id());
            (migration.check_fn_down)();
            Ok(())
        }
    }

    /// Shared record of the migrations the check callbacks consider applied.
    type RanMigrations = Rc<RefCell<HashSet<usize>>>;

    /// Copy out the current contents of `ran` without keeping it borrowed, so
    /// the callbacks can mutate it during the next migrator call.
    fn snapshot(ran: &RanMigrations) -> HashSet<usize> {
        ran.borrow().clone()
    }

    /// Build an expected set of migration IDs from a slice.
    fn ids(list: &[usize]) -> HashSet<usize> {
        list.iter().copied().collect()
    }

    #[test]
    fn test_migrations_run_order() {
        let ran: RanMigrations = Rc::new(RefCell::new(HashSet::new()));
        let mut migrator = Migrator::new(TestAdapterWithCheck::default());

        // The root has no parent to wait for and is checked by its children.
        let (on_up, on_down) = (Rc::clone(&ran), Rc::clone(&ran));
        migrator
            .register(TestMigrationWithCheck::new(
                1,
                HashSet::new(),
                move || {
                    on_up.borrow_mut().insert(1);
                },
                move || {
                    on_down.borrow_mut().remove(&1);
                },
            ))
            .unwrap();

        // Making a binary tree with checks to make sure run order is correct.
        // Node `id` depends on `id / 2`: going up, each node checks that its
        // parent already ran; going down, that neither child is still applied.
        //             1
        //      2              3
        //  4      5       6       7
        // 8 9   10 11   12 13   14 15
        for id in 2_usize..16 {
            let parent = id / 2;
            let (on_up, on_down) = (Rc::clone(&ran), Rc::clone(&ran));
            migrator
                .register(TestMigrationWithCheck::new(
                    id,
                    HashSet::from([parent]),
                    move || {
                        let mut applied = on_up.borrow_mut();
                        if !applied.contains(&parent) {
                            panic!("Up called before dependency id:{}, dep:{}", id, parent);
                        }
                        applied.insert(id);
                    },
                    move || {
                        let mut applied = on_down.borrow_mut();
                        for child in [id * 2, id * 2 + 1] {
                            if applied.contains(&child) {
                                panic!("Down called before dependant id:{}, dep:{}", id, child);
                            }
                        }
                        applied.remove(&id);
                    },
                ))
                .unwrap();
        }

        migrator.up(None).unwrap();
        assert_eq!(snapshot(&ran), (1..16).collect::<HashSet<_>>());

        migrator.down(None).unwrap();
        assert_eq!(snapshot(&ran), HashSet::new());

        migrator.up(None).unwrap();

        // Should result in
        //            1
        //     2              3
        //                6       7
        //              12 13   14 15
        migrator.down(Some(2)).unwrap();
        assert_eq!(snapshot(&ran), ids(&[1, 2, 3, 6, 7, 12, 13, 14, 15]));

        // Should result in
        //            1
        //     2              3
        migrator.down(Some(3)).unwrap();
        assert_eq!(snapshot(&ran), ids(&[1, 2, 3]));

        migrator.down(None).unwrap();

        // Should result in
        //            1
        //                    3
        migrator.up(Some(3)).unwrap();
        assert_eq!(snapshot(&ran), ids(&[1, 3]));

        // Should result in
        //            1
        //     2              3
        //  4
        //   9
        migrator.up(Some(9)).unwrap();
        assert_eq!(snapshot(&ran), ids(&[1, 2, 3, 4, 9]));

        // Should result in
        //            1
        //     2              3
        //  4              6
        //   9           12
        migrator.up(Some(12)).unwrap();
        assert_eq!(snapshot(&ran), ids(&[1, 2, 3, 4, 6, 9, 12]));

        migrator.up(None).unwrap();
        assert_eq!(snapshot(&ran), (1..16).collect::<HashSet<_>>());
    }
}