1#![cfg_attr(feature = "_docs", feature(doc_cfg))]
8#![warn(clippy::pedantic)]
9#![allow(
10 clippy::cast_possible_truncation,
11 clippy::cast_possible_wrap,
12 clippy::cast_sign_loss,
13 clippy::cast_lossless,
14 clippy::unreadable_literal,
15 clippy::doc_markdown,
16 clippy::module_name_repetitions
17)]
18
19use db::{AppliedMigration, Migrations};
20use futures_core::future::LocalBoxFuture;
21use itertools::{EitherOrBoth, Itertools};
22use sha2::{Digest, Sha256};
23use sqlx::{ConnectOptions, Connection, Database, Executor, Pool};
24use state::TypeMap;
25use std::{
26 borrow::Cow,
27 str::FromStr,
28 sync::Arc,
29 time::{Duration, Instant},
30};
31
32pub mod context;
33pub mod db;
34pub mod error;
35
36pub use context::MigrationContext;
37pub use error::Error;
38
39#[cfg(feature = "cli")]
40#[cfg_attr(feature = "_docs", doc(cfg(feature = "cli")))]
41pub mod cli;
42
43#[cfg(feature = "generate")]
44#[cfg_attr(feature = "_docs", doc(cfg(feature = "generate")))]
45mod gen;
46
47#[cfg(feature = "generate")]
48#[cfg_attr(feature = "_docs", doc(cfg(feature = "generate")))]
49pub use gen::generate;
50
51type MigrationFn<DB> =
52 Box<dyn Fn(&mut MigrationContext<DB>) -> LocalBoxFuture<Result<(), MigrationError>>>;
53
/// Name of the bookkeeping table a [`Migrator`] uses until it is overridden
/// with [`Migrator::set_migrations_table`].
pub const DEFAULT_MIGRATIONS_TABLE: &str = "_sqlx_migrations";
56
57pub mod prelude {
59 pub use super::Migration;
60 pub use super::MigrationContext;
61 pub use super::MigrationError;
62 pub use super::MigrationStatus;
63 pub use super::MigrationSummary;
64 pub use super::Migrator;
65 pub use super::MigratorOptions;
66}
67
68pub struct Migration<DB: Database> {
92 name: Cow<'static, str>,
93 up: MigrationFn<DB>,
94 down: Option<MigrationFn<DB>>,
95}
96
97impl<DB: Database> Migration<DB> {
98 pub fn new(
101 name: impl Into<Cow<'static, str>>,
102 up: impl Fn(&mut MigrationContext<DB>) -> LocalBoxFuture<Result<(), MigrationError>> + 'static,
103 ) -> Self {
104 Self {
105 name: name.into(),
106 up: Box::new(up),
107 down: None,
108 }
109 }
110
111 #[must_use]
113 pub fn reversible(
114 mut self,
115 down: impl Fn(&mut MigrationContext<DB>) -> LocalBoxFuture<Result<(), MigrationError>> + 'static,
116 ) -> Self {
117 self.down = Some(Box::new(down));
118 self
119 }
120
121 #[must_use]
123 pub fn revertible(
124 self,
125 down: impl Fn(&mut MigrationContext<DB>) -> LocalBoxFuture<Result<(), MigrationError>> + 'static,
126 ) -> Self {
127 self.reversible(down)
128 }
129
130 #[must_use]
132 pub fn name(&self) -> &str {
133 self.name.as_ref()
134 }
135
136 #[must_use]
138 pub fn is_reversible(&self) -> bool {
139 self.down.is_some()
140 }
141
142 #[must_use]
144 pub fn is_revertible(&self) -> bool {
145 self.down.is_some()
146 }
147}
148
149impl<DB: Database> Eq for Migration<DB> {}
150impl<DB: Database> PartialEq for Migration<DB> {
151 fn eq(&self, other: &Self) -> bool {
152 self.name == other.name
153 }
154}
155
156#[must_use]
203pub struct Migrator<Db>
204where
205 Db: Database,
206 Db::Connection: db::Migrations,
207{
208 options: MigratorOptions,
209 conn: Db::Connection,
210 table: Cow<'static, str>,
211 migrations: Vec<Migration<Db>>,
212 extensions: Arc<TypeMap!(Send + Sync)>,
213}
214
215impl<Db> Migrator<Db>
216where
217 Db: Database,
218 Db::Connection: db::Migrations,
219 for<'a> &'a mut Db::Connection: Executor<'a>,
220{
221 pub fn new(conn: Db::Connection) -> Self {
223 Self {
224 options: MigratorOptions::default(),
225 conn,
226 table: Cow::Borrowed(DEFAULT_MIGRATIONS_TABLE),
227 migrations: Vec::default(),
228 extensions: Arc::new(<TypeMap![Send + Sync]>::new()),
229 }
230 }
231
232 pub async fn connect(url: &str) -> Result<Self, sqlx::Error> {
241 let mut opts: <<Db as Database>::Connection as Connection>::Options = url.parse()?;
242 opts = opts.disable_statement_logging();
243
244 let mut conn = Db::Connection::connect_with(&opts).await?;
245 conn.execute(
246 r#"--sql
247 SET client_min_messages TO WARNING;
248 "#,
249 )
250 .await?;
251
252 Ok(Self {
253 options: MigratorOptions::default(),
254 conn,
255 table: Cow::Borrowed(DEFAULT_MIGRATIONS_TABLE),
256 migrations: Vec::default(),
257 extensions: Arc::new(<TypeMap![Send + Sync]>::new()),
258 })
259 }
260
261 pub async fn connect_with(
267 options: &<Db::Connection as Connection>::Options,
268 ) -> Result<Self, sqlx::Error> {
269 let mut conn = Db::Connection::connect_with(options).await?;
270 conn.execute(
271 r#"--sql
272 SET client_min_messages TO WARNING;
273 "#,
274 )
275 .await?;
276
277 Ok(Self {
278 options: MigratorOptions::default(),
279 conn,
280 table: Cow::Borrowed(DEFAULT_MIGRATIONS_TABLE),
281 migrations: Vec::default(),
282 extensions: Arc::new(<TypeMap![Send + Sync]>::new()),
283 })
284 }
285
286 pub async fn connect_with_pool(pool: &Pool<Db>) -> Result<Self, sqlx::Error> {
294 let mut conn = pool.acquire().await?;
295 conn.execute(
296 r#"--sql
297 SET client_min_messages TO WARNING;
298 "#,
299 )
300 .await?;
301
302 Ok(Self {
303 options: MigratorOptions::default(),
304 conn: conn.detach(),
305 table: Cow::Borrowed(DEFAULT_MIGRATIONS_TABLE),
306 migrations: Vec::default(),
307 extensions: Arc::new(<TypeMap![Send + Sync]>::new()),
308 })
309 }
310
311 pub fn set_migrations_table(&mut self, name: impl AsRef<str>) {
315 self.table = Cow::Owned(name.as_ref().to_string());
316 }
317
318 pub fn add_migrations(&mut self, migrations: impl IntoIterator<Item = Migration<Db>>) {
320 self.migrations.extend(migrations);
321 }
322
323 pub fn set_options(&mut self, options: MigratorOptions) {
325 self.options = options;
326 }
327
328 pub fn with<T: Send + Sync + 'static>(&mut self, value: T) -> &mut Self {
330 self.set(value);
331 self
332 }
333
334 pub fn set<T: Send + Sync + 'static>(&mut self, value: T) {
336 self.extensions.set(value);
337 }
338
339 pub fn local_migrations(&self) -> &[Migration<Db>] {
343 &self.migrations
344 }
345}
346
347impl<Db> Migrator<Db>
348where
349 Db: Database,
350 Db::Connection: db::Migrations,
351 for<'a> &'a mut Db::Connection: Executor<'a>,
352{
353 #[allow(clippy::missing_panics_doc)]
363 pub async fn migrate(mut self, target_version: u64) -> Result<MigrationSummary, Error> {
364 self.local_migration(target_version)?;
365 self.conn.ensure_migrations_table(&self.table).await?;
366
367 let db_migrations = self.conn.list_migrations(&self.table).await?;
368
369 self.check_migrations(&db_migrations)?;
370
371 let to_apply = self.migrations.iter();
372
373 let db_version = db_migrations.len() as _;
374
375 let mut conn = self.conn;
376 conn.execute("BEGIN").await?;
377
378 for (idx, mig) in to_apply.enumerate() {
379 let mig_version = idx as u64 + 1;
380
381 if mig_version > target_version {
382 break;
383 }
384
385 if mig_version <= db_version {
386 continue;
387 }
388
389 let start = Instant::now();
390
391 tracing::info!(
392 version = mig_version,
393 name = %mig.name,
394 "applying migration"
395 );
396
397 let hasher = Sha256::new();
398
399 let mut ctx = MigrationContext {
407 hash_only: true,
408 ext: self.extensions.clone(),
409 hasher,
410 conn,
411 };
412
413 (*mig.up)(&mut ctx)
414 .await
415 .map_err(|error| Error::Migration {
416 name: mig.name.clone(),
417 version: mig_version,
418 error,
419 })?;
420
421 let checksum = std::mem::take(&mut ctx.hasher).finalize().to_vec();
422
423 ctx.hash_only = false;
424
425 (*mig.up)(&mut ctx)
426 .await
427 .map_err(|error| Error::Migration {
428 name: mig.name.clone(),
429 version: mig_version,
430 error,
431 })?;
432
433 let execution_time = start.elapsed();
434
435 if self.options.verify_checksums {
436 if let Some(db_mig) = db_migrations.get(idx) {
437 if db_mig.checksum != checksum {
438 ctx.conn.execute("ROLLBACK").await?;
439
440 return Err(Error::ChecksumMismatch {
441 version: mig_version,
442 local_checksum: checksum.clone().into(),
443 db_checksum: db_mig.checksum.clone(),
444 });
445 }
446 }
447 }
448
449 ctx.conn
450 .add_migration(
451 &self.table,
452 AppliedMigration {
453 version: mig_version,
454 name: mig.name.clone(),
455 checksum: checksum.into(),
456 execution_time,
457 },
458 )
459 .await?;
460
461 conn = ctx.conn;
462
463 tracing::info!(
464 version = mig_version,
465 name = %mig.name,
466 execution_time = %humantime::Duration::from(execution_time),
467 "migration applied"
468 );
469 }
470
471 tracing::info!("committing changes");
472 conn.execute("COMMIT").await?;
473
474 Ok(MigrationSummary {
475 old_version: if db_migrations.is_empty() {
476 None
477 } else {
478 Some(db_migrations.len() as _)
479 },
480 new_version: Some(target_version.max(db_version)),
481 })
482 }
483
484 pub async fn migrate_all(self) -> Result<MigrationSummary, Error> {
490 if self.migrations.is_empty() {
491 return Ok(MigrationSummary {
492 new_version: None,
493 old_version: None,
494 });
495 }
496 let migrations = self.migrations.len() as _;
497 self.migrate(migrations).await
498 }
499
500 #[allow(clippy::missing_panics_doc)]
509 pub async fn revert(mut self, target_version: u64) -> Result<MigrationSummary, Error> {
510 self.local_migration(target_version)?;
511 self.conn.ensure_migrations_table(&self.table).await?;
512
513 let db_migrations = self.conn.list_migrations(&self.table).await?;
514
515 self.check_migrations(&db_migrations)?;
516
517 let to_revert = self
518 .migrations
519 .iter()
520 .enumerate()
521 .skip_while(|(idx, _)| idx + 1 < target_version as _)
522 .take_while(|(idx, _)| *idx < db_migrations.len())
523 .collect::<Vec<_>>()
524 .into_iter()
525 .rev();
526
527 let mut conn = self.conn;
528 conn.execute("BEGIN").await?;
529
530 for (idx, mig) in to_revert {
531 let version = idx as u64 + 1;
532
533 let start = Instant::now();
534
535 tracing::info!(
536 version,
537 name = %mig.name,
538 "reverting migration"
539 );
540
541 let hasher = Sha256::new();
542
543 let mut ctx = MigrationContext {
544 hash_only: false,
545 ext: self.extensions.clone(),
546 hasher,
547 conn,
548 };
549
550 match &mig.down {
551 Some(down) => {
552 down(&mut ctx).await.map_err(|error| Error::Revert {
553 name: mig.name.clone(),
554 version,
555 error,
556 })?;
557 }
558 None => {
559 tracing::warn!(
560 version,
561 name = %mig.name,
562 "no down migration found"
563 );
564 }
565 }
566
567 let execution_time = start.elapsed();
568
569 ctx.conn.remove_migration(&self.table, version).await?;
570
571 conn = ctx.conn;
572
573 tracing::info!(
574 version,
575 name = %mig.name,
576 execution_time = %humantime::Duration::from(execution_time),
577 "migration reverted"
578 );
579 }
580
581 tracing::info!("committing changes");
582 conn.execute("COMMIT").await?;
583
584 Ok(MigrationSummary {
585 old_version: if db_migrations.is_empty() {
586 None
587 } else {
588 Some(db_migrations.len() as _)
589 },
590 new_version: if target_version == 1 {
591 None
592 } else {
593 Some(target_version - 1)
594 },
595 })
596 }
597
598 pub async fn revert_all(self) -> Result<MigrationSummary, Error> {
604 self.revert(1).await
605 }
606
607 #[allow(clippy::missing_panics_doc)]
623 pub async fn force_version(mut self, version: u64) -> Result<MigrationSummary, Error> {
624 self.conn.ensure_migrations_table(&self.table).await?;
625
626 let db_migrations = self.conn.list_migrations(&self.table).await?;
627
628 if version == 0 {
629 self.conn.clear_migrations(&self.table).await?;
630 return Ok(MigrationSummary {
631 old_version: if db_migrations.is_empty() {
632 None
633 } else {
634 Some(db_migrations.len() as _)
635 },
636 new_version: None,
637 });
638 }
639
640 self.local_migration(version)?;
641
642 let migrations = self
643 .migrations
644 .iter()
645 .enumerate()
646 .take_while(|(idx, _)| *idx < version as usize);
647
648 self.conn.clear_migrations(&self.table).await?;
649
650 let mut conn = self.conn;
651 conn.execute("BEGIN").await?;
652
653 for (idx, mig) in migrations {
654 let mig_version = idx as u64 + 1;
655
656 let hasher = Sha256::new();
657
658 let mut ctx = MigrationContext {
659 hash_only: true,
660 ext: self.extensions.clone(),
661 hasher,
662 conn,
663 };
664
665 (*mig.up)(&mut ctx)
666 .await
667 .map_err(|error| Error::Migration {
668 name: mig.name.clone(),
669 version: mig_version,
670 error,
671 })?;
672
673 let checksum = std::mem::take(&mut ctx.hasher).finalize().to_vec();
674
675 ctx.conn
676 .add_migration(
677 &self.table,
678 AppliedMigration {
679 version: mig_version,
680 name: mig.name.clone(),
681 checksum: checksum.into(),
682 execution_time: Duration::default(),
683 },
684 )
685 .await?;
686
687 conn = ctx.conn;
688
689 tracing::info!(
690 version = idx + 1,
691 name = %mig.name,
692 "migration forcibly set as applied"
693 );
694 }
695
696 tracing::info!("committing changes");
697 conn.execute("COMMIT").await?;
698
699 Ok(MigrationSummary {
700 old_version: if db_migrations.is_empty() {
701 None
702 } else {
703 Some(db_migrations.len() as _)
704 },
705 new_version: Some(version),
706 })
707 }
708
709 #[allow(clippy::missing_panics_doc)]
723 pub async fn verify(mut self) -> Result<(), Error> {
724 self.conn.ensure_migrations_table(&self.table).await?;
725 let migrations = self.conn.list_migrations(&self.table).await?;
726 self.check_migrations(&migrations)?;
727
728 if self.options.verify_checksums {
729 for res in self.verify_checksums(&migrations).await?.1 {
730 res?;
731 }
732 }
733
734 Ok(())
735 }
736
737 #[allow(clippy::missing_panics_doc)]
744 pub async fn status(mut self) -> Result<Vec<MigrationStatus>, Error> {
745 self.conn.ensure_migrations_table(&self.table).await?;
746
747 let migrations = self.conn.list_migrations(&self.table).await?;
748
749 let mut status = Vec::with_capacity(self.migrations.len());
750
751 let (migrator, checksums) = self.verify_checksums(&migrations).await?;
752 self = migrator;
753
754 for (idx, pair) in self.migrations.iter().zip_longest(migrations).enumerate() {
755 let version = idx as u64 + 1;
756
757 match pair {
758 EitherOrBoth::Both(local, db) => status.push(MigrationStatus {
759 version,
760 name: local.name.clone().into_owned(),
761 reversible: local.is_reversible(),
762 applied: Some(db),
763 missing_local: false,
764 checksum_ok: checksums.get(idx).map_or(true, Result::is_ok),
765 }),
766 EitherOrBoth::Left(local) => status.push(MigrationStatus {
767 version,
768 name: local.name.clone().into_owned(),
769 reversible: local.is_reversible(),
770 applied: None,
771 missing_local: false,
772 checksum_ok: checksums.get(idx).map_or(true, Result::is_ok),
773 }),
774 EitherOrBoth::Right(r) => status.push(MigrationStatus {
775 version: r.version,
776 name: r.name.clone().into_owned(),
777 reversible: false,
778 applied: Some(r),
779 missing_local: true,
780 checksum_ok: checksums.get(idx).map_or(true, Result::is_ok),
781 }),
782 }
783 }
784
785 Ok(status)
786 }
787}
788
789impl<Db> Migrator<Db>
790where
791 Db: Database,
792 Db::Connection: db::Migrations,
793 for<'a> &'a mut Db::Connection: Executor<'a>,
794{
795 fn local_migration(&self, version: u64) -> Result<&Migration<Db>, Error> {
796 if version == 0 {
797 return Err(Error::InvalidVersion {
798 version,
799 min_version: 1,
800 max_version: self.migrations.len() as _,
801 });
802 }
803
804 if self.migrations.is_empty() {
805 return Err(Error::InvalidVersion {
806 version,
807 min_version: 1,
808 max_version: self.migrations.len() as _,
809 });
810 }
811
812 let idx = version - 1;
813
814 self.migrations
815 .get(idx as usize)
816 .ok_or(Error::InvalidVersion {
817 version,
818 min_version: 1,
819 max_version: self.migrations.len() as _,
820 })
821 }
822
823 fn check_migrations(&mut self, migrations: &[AppliedMigration<'_>]) -> Result<(), Error> {
824 if self.migrations.len() < migrations.len() {
825 return Err(Error::MissingMigrations {
826 local_count: self.migrations.len(),
827 db_count: migrations.len(),
828 });
829 }
830
831 for (idx, (db_migration, local_migration)) in
832 migrations.iter().zip(self.migrations.iter()).enumerate()
833 {
834 let version = idx as u64 + 1;
835
836 if self.options.verify_names && db_migration.name != local_migration.name {
837 return Err(Error::NameMismatch {
838 version,
839 local_name: local_migration.name.clone(),
840 db_name: db_migration.name.to_string().into(),
841 });
842 }
843 }
844
845 Ok(())
846 }
847
848 async fn verify_checksums(
849 mut self,
850 migrations: &[AppliedMigration<'_>],
851 ) -> Result<(Self, Vec<Result<(), Error>>), Error> {
852 let mut results = Vec::with_capacity(self.migrations.len());
853
854 let local_migrations = self.migrations.iter();
855
856 let mut conn = self.conn;
857
858 for (idx, mig) in local_migrations.enumerate() {
859 let mig_version = idx as u64 + 1;
860
861 let hasher = Sha256::new();
862
863 let mut ctx = MigrationContext {
864 hash_only: true,
865 ext: self.extensions.clone(),
866 hasher,
867 conn,
868 };
869
870 (*mig.up)(&mut ctx)
871 .await
872 .map_err(|error| Error::Migration {
873 name: mig.name.clone(),
874 version: mig_version,
875 error,
876 })?;
877
878 let checksum = std::mem::take(&mut ctx.hasher).finalize().to_vec();
879 conn = ctx.conn;
880
881 if let Some(db_mig) = migrations.get(idx) {
882 if db_mig.checksum == checksum {
883 results.push(Ok(()));
884 } else {
885 results.push(Err(Error::ChecksumMismatch {
886 version: mig_version,
887 local_checksum: checksum.clone().into(),
888 db_checksum: db_mig.checksum.clone().into_owned().into(),
889 }));
890 }
891 }
892 }
893
894 conn.execute("ROLLBACK").await?;
895 self.conn = conn;
896
897 Ok((self, results))
898 }
899}
900
/// Switches controlling which consistency checks a [`Migrator`] performs.
#[derive(Debug)]
pub struct MigratorOptions {
    /// Compare recomputed checksums of local migrations with the recorded
    /// ones.
    pub verify_checksums: bool,
    /// Compare the names of recorded migrations with the local ones at the
    /// same position.
    pub verify_names: bool,
}
909
910impl Default for MigratorOptions {
911 fn default() -> Self {
912 Self {
913 verify_checksums: true,
914 verify_names: true,
915 }
916 }
917}
918
/// Outcome of an operation that changes the recorded migration version.
///
/// For both fields `None` means "no migration recorded".
#[derive(Debug, Clone)]
pub struct MigrationSummary {
    /// Version recorded before the operation.
    pub old_version: Option<u64>,
    /// Version reported after the operation.
    pub new_version: Option<u64>,
}
927
928#[derive(Debug, Clone)]
930pub struct MigrationStatus {
931 pub version: u64,
933 pub name: String,
935 pub reversible: bool,
937 pub applied: Option<db::AppliedMigration<'static>>,
939 pub missing_local: bool,
942 pub checksum_ok: bool,
944}
945
946pub type MigrationError = anyhow::Error;
950
/// Database backends that can be selected by name.
///
/// With the `cli` feature enabled the enum also derives `clap::ValueEnum`.
/// It is `#[non_exhaustive]`, so further variants may be added.
#[cfg_attr(feature = "cli", derive(clap::ValueEnum))]
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub enum DatabaseType {
    /// Parsed from `"postgres"`.
    Postgres,
    /// Parsed from `"sqlite"`.
    Sqlite,
    /// Parsed from `"any"`.
    Any,
}
960
961impl DatabaseType {
962 fn sqlx_type(self) -> &'static str {
963 match self {
964 DatabaseType::Postgres => "Postgres",
965 DatabaseType::Sqlite => "Sqlite",
966 DatabaseType::Any => "Any",
967 }
968 }
969}
970
971impl FromStr for DatabaseType {
972 type Err = anyhow::Error;
973
974 fn from_str(s: &str) -> Result<Self, Self::Err> {
975 match s {
976 "postgres" => Ok(Self::Postgres),
977 "sqlite" => Ok(Self::Sqlite),
978 "any" => Ok(Self::Any),
979 db => Err(anyhow::anyhow!("invalid database type `{}`", db)),
980 }
981 }
982}