sea_orm_migration/
migrator.rs

1use std::collections::HashSet;
2use std::fmt::Display;
3use std::future::Future;
4use std::pin::Pin;
5use std::time::SystemTime;
6use tracing::info;
7
8use sea_orm::sea_query::{
9    self, Alias, Expr, ExprTrait, ForeignKey, IntoIden, JoinType, Order, Query, SelectStatement,
10    SimpleExpr, Table, extension::postgres::Type,
11};
12use sea_orm::{
13    ActiveModelTrait, ActiveValue, Condition, ConnectionTrait, DbBackend, DbErr, DeriveIden,
14    DynIden, EntityTrait, FromQueryResult, Iterable, QueryFilter, Schema, Statement,
15    TransactionTrait,
16};
17#[allow(unused_imports)]
18use sea_schema::probe::SchemaProbe;
19
20use super::{IntoSchemaManagerConnection, MigrationTrait, SchemaManager, seaql_migrations};
21
/// Status of migration
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum MigrationStatus {
    /// Not yet applied
    Pending,
    /// Applied
    Applied,
}

impl Display for MigrationStatus {
    /// Writes the variant name (`Pending` or `Applied`) verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match *self {
            Self::Pending => "Pending",
            Self::Applied => "Applied",
        })
    }
}
40
41pub struct Migration {
42    migration: Box<dyn MigrationTrait>,
43    status: MigrationStatus,
44}
45
46impl Migration {
47    /// Get migration name from MigrationName trait implementation
48    pub fn name(&self) -> &str {
49        self.migration.name()
50    }
51
52    /// Get migration status
53    pub fn status(&self) -> MigrationStatus {
54        self.status
55    }
56}
57
58/// Performing migrations on a database
59#[async_trait::async_trait]
60pub trait MigratorTrait: Send {
61    /// Vector of migrations in time sequence
62    fn migrations() -> Vec<Box<dyn MigrationTrait>>;
63
64    /// Name of the migration table, it is `seaql_migrations` by default
65    fn migration_table_name() -> DynIden {
66        seaql_migrations::Entity.into_iden()
67    }
68
69    /// Get list of migrations wrapped in `Migration` struct
70    fn get_migration_files() -> Vec<Migration> {
71        Self::migrations()
72            .into_iter()
73            .map(|migration| Migration {
74                migration,
75                status: MigrationStatus::Pending,
76            })
77            .collect()
78    }
79
80    /// Get list of applied migrations from database
81    async fn get_migration_models<C>(db: &C) -> Result<Vec<seaql_migrations::Model>, DbErr>
82    where
83        C: ConnectionTrait,
84    {
85        Self::install(db).await?;
86        let stmt = Query::select()
87            .table_name(Self::migration_table_name())
88            .columns(seaql_migrations::Column::iter().map(IntoIden::into_iden))
89            .order_by(seaql_migrations::Column::Version, Order::Asc)
90            .to_owned();
91        let builder = db.get_database_backend();
92        seaql_migrations::Model::find_by_statement(builder.build(&stmt))
93            .all(db)
94            .await
95    }
96
97    /// Get list of migrations with status
98    async fn get_migration_with_status<C>(db: &C) -> Result<Vec<Migration>, DbErr>
99    where
100        C: ConnectionTrait,
101    {
102        Self::install(db).await?;
103        let mut migration_files = Self::get_migration_files();
104        let migration_models = Self::get_migration_models(db).await?;
105
106        let migration_in_db: HashSet<String> = migration_models
107            .into_iter()
108            .map(|model| model.version)
109            .collect();
110        let migration_in_fs: HashSet<String> = migration_files
111            .iter()
112            .map(|file| file.migration.name().to_string())
113            .collect();
114
115        let pending_migrations = &migration_in_fs - &migration_in_db;
116        for migration_file in migration_files.iter_mut() {
117            if !pending_migrations.contains(migration_file.migration.name()) {
118                migration_file.status = MigrationStatus::Applied;
119            }
120        }
121
122        let missing_migrations_in_fs = &migration_in_db - &migration_in_fs;
123        let errors: Vec<String> = missing_migrations_in_fs
124            .iter()
125            .map(|missing_migration| {
126                format!("Migration file of version '{missing_migration}' is missing, this migration has been applied but its file is missing")
127            }).collect();
128
129        if !errors.is_empty() {
130            Err(DbErr::Custom(errors.join("\n")))
131        } else {
132            Ok(migration_files)
133        }
134    }
135
136    /// Get list of pending migrations
137    async fn get_pending_migrations<C>(db: &C) -> Result<Vec<Migration>, DbErr>
138    where
139        C: ConnectionTrait,
140    {
141        Self::install(db).await?;
142        Ok(Self::get_migration_with_status(db)
143            .await?
144            .into_iter()
145            .filter(|file| file.status == MigrationStatus::Pending)
146            .collect())
147    }
148
149    /// Get list of applied migrations
150    async fn get_applied_migrations<C>(db: &C) -> Result<Vec<Migration>, DbErr>
151    where
152        C: ConnectionTrait,
153    {
154        Self::install(db).await?;
155        Ok(Self::get_migration_with_status(db)
156            .await?
157            .into_iter()
158            .filter(|file| file.status == MigrationStatus::Applied)
159            .collect())
160    }
161
162    /// Create migration table `seaql_migrations` in the database
163    async fn install<C>(db: &C) -> Result<(), DbErr>
164    where
165        C: ConnectionTrait,
166    {
167        let builder = db.get_database_backend();
168        let table_name = Self::migration_table_name();
169        let schema = Schema::new(builder);
170        let mut stmt = schema
171            .create_table_from_entity(seaql_migrations::Entity)
172            .table_name(table_name);
173        stmt.if_not_exists();
174        db.execute(&stmt).await.map(|_| ())
175    }
176
177    /// Check the status of all migrations
178    async fn status<C>(db: &C) -> Result<(), DbErr>
179    where
180        C: ConnectionTrait,
181    {
182        Self::install(db).await?;
183
184        info!("Checking migration status");
185
186        for Migration { migration, status } in Self::get_migration_with_status(db).await? {
187            info!("Migration '{}'... {}", migration.name(), status);
188        }
189
190        Ok(())
191    }
192
193    /// Drop all tables from the database, then reapply all migrations
194    async fn fresh<'c, C>(db: C) -> Result<(), DbErr>
195    where
196        C: IntoSchemaManagerConnection<'c>,
197    {
198        exec_with_connection::<'_, _, _>(db, move |manager| {
199            Box::pin(async move { exec_fresh::<Self>(manager).await })
200        })
201        .await
202    }
203
204    /// Rollback all applied migrations, then reapply all migrations
205    async fn refresh<'c, C>(db: C) -> Result<(), DbErr>
206    where
207        C: IntoSchemaManagerConnection<'c>,
208    {
209        exec_with_connection::<'_, _, _>(db, move |manager| {
210            Box::pin(async move {
211                exec_down::<Self>(manager, None).await?;
212                exec_up::<Self>(manager, None).await
213            })
214        })
215        .await
216    }
217
218    /// Rollback all applied migrations
219    async fn reset<'c, C>(db: C) -> Result<(), DbErr>
220    where
221        C: IntoSchemaManagerConnection<'c>,
222    {
223        exec_with_connection::<'_, _, _>(db, move |manager| {
224            Box::pin(async move {
225                // Rollback all applied migrations first
226                exec_down::<Self>(manager, None).await?;
227
228                // Then drop the migration table itself
229                let mut stmt = Table::drop();
230                stmt.table(Self::migration_table_name())
231                    .if_exists()
232                    .cascade();
233                manager.drop_table(stmt).await?;
234
235                Ok(())
236            })
237        })
238        .await
239    }
240
241    /// Uninstall migration tracking table only (non-destructive)
242    /// This will drop the `seaql_migrations` table but won't rollback other schema changes.
243    async fn uninstall<'c, C>(db: C) -> Result<(), DbErr>
244    where
245        C: IntoSchemaManagerConnection<'c>,
246    {
247        exec_with_connection::<'_, _, _>(db, move |manager| {
248            Box::pin(async move {
249                let mut stmt = Table::drop();
250                stmt.table(Self::migration_table_name())
251                    .if_exists()
252                    .cascade();
253                manager.drop_table(stmt).await?;
254                Ok(())
255            })
256        })
257        .await
258    }
259
260    /// Apply pending migrations
261    async fn up<'c, C>(db: C, steps: Option<u32>) -> Result<(), DbErr>
262    where
263        C: IntoSchemaManagerConnection<'c>,
264    {
265        exec_with_connection::<'_, _, _>(db, move |manager| {
266            Box::pin(async move { exec_up::<Self>(manager, steps).await })
267        })
268        .await
269    }
270
271    /// Rollback applied migrations
272    async fn down<'c, C>(db: C, steps: Option<u32>) -> Result<(), DbErr>
273    where
274        C: IntoSchemaManagerConnection<'c>,
275    {
276        exec_with_connection::<'_, _, _>(db, move |manager| {
277            Box::pin(async move { exec_down::<Self>(manager, steps).await })
278        })
279        .await
280    }
281}
282
283async fn exec_with_connection<'c, C, F>(db: C, f: F) -> Result<(), DbErr>
284where
285    C: IntoSchemaManagerConnection<'c>,
286    F: for<'b> Fn(
287        &'b SchemaManager<'_>,
288    ) -> Pin<Box<dyn Future<Output = Result<(), DbErr>> + Send + 'b>>,
289{
290    let db = db.into_schema_manager_connection();
291
292    match db.get_database_backend() {
293        DbBackend::Postgres => {
294            let transaction = db.begin().await?;
295            let manager = SchemaManager::new(&transaction);
296            f(&manager).await?;
297            transaction.commit().await
298        }
299        DbBackend::MySql | DbBackend::Sqlite => {
300            let manager = SchemaManager::new(db);
301            f(&manager).await
302        }
303        db => Err(DbErr::BackendNotSupported {
304            db: db.as_str(),
305            ctx: "exec_with_connection",
306        }),
307    }
308}
309
310async fn exec_fresh<M>(manager: &SchemaManager<'_>) -> Result<(), DbErr>
311where
312    M: MigratorTrait + ?Sized,
313{
314    let db = manager.get_connection();
315
316    M::install(db).await?;
317    let db_backend = db.get_database_backend();
318
319    // Temporarily disable the foreign key check
320    if db_backend == DbBackend::Sqlite {
321        info!("Disabling foreign key check");
322        db.execute_raw(Statement::from_string(
323            db_backend,
324            "PRAGMA foreign_keys = OFF".to_owned(),
325        ))
326        .await?;
327        info!("Foreign key check disabled");
328    }
329
330    // Drop all foreign keys
331    if db_backend == DbBackend::MySql {
332        info!("Dropping all foreign keys");
333        let stmt = query_mysql_foreign_keys(db);
334        let rows = db.query_all(&stmt).await?;
335        for row in rows.into_iter() {
336            let constraint_name: String = row.try_get("", "CONSTRAINT_NAME")?;
337            let table_name: String = row.try_get("", "TABLE_NAME")?;
338            info!(
339                "Dropping foreign key '{}' from table '{}'",
340                constraint_name, table_name
341            );
342            let mut stmt = ForeignKey::drop();
343            stmt.table(Alias::new(table_name.as_str()))
344                .name(constraint_name.as_str());
345            db.execute(&stmt).await?;
346            info!("Foreign key '{}' has been dropped", constraint_name);
347        }
348        info!("All foreign keys dropped");
349    }
350
351    // Drop all tables
352    let stmt = query_tables(db)?;
353    let rows = db.query_all(&stmt).await?;
354    for row in rows.into_iter() {
355        let table_name: String = row.try_get("", "table_name")?;
356        info!("Dropping table '{}'", table_name);
357        let mut stmt = Table::drop();
358        stmt.table(Alias::new(table_name.as_str()))
359            .if_exists()
360            .cascade();
361        db.execute(&stmt).await?;
362        info!("Table '{}' has been dropped", table_name);
363    }
364
365    // Drop all types
366    if db_backend == DbBackend::Postgres {
367        info!("Dropping all types");
368        let stmt = query_pg_types(db);
369        let rows = db.query_all(&stmt).await?;
370        for row in rows {
371            let type_name: String = row.try_get("", "typname")?;
372            info!("Dropping type '{}'", type_name);
373            let mut stmt = Type::drop();
374            stmt.name(Alias::new(&type_name));
375            db.execute(&stmt).await?;
376            info!("Type '{}' has been dropped", type_name);
377        }
378    }
379
380    // Restore the foreign key check
381    if db_backend == DbBackend::Sqlite {
382        info!("Restoring foreign key check");
383        db.execute_raw(Statement::from_string(
384            db_backend,
385            "PRAGMA foreign_keys = ON".to_owned(),
386        ))
387        .await?;
388        info!("Foreign key check restored");
389    }
390
391    // Reapply all migrations
392    exec_up::<M>(manager, None).await
393}
394
395async fn exec_up<M>(manager: &SchemaManager<'_>, mut steps: Option<u32>) -> Result<(), DbErr>
396where
397    M: MigratorTrait + ?Sized,
398{
399    let db = manager.get_connection();
400
401    M::install(db).await?;
402
403    if let Some(steps) = steps {
404        info!("Applying {} pending migrations", steps);
405    } else {
406        info!("Applying all pending migrations");
407    }
408
409    let migrations = M::get_pending_migrations(db).await?.into_iter();
410    if migrations.len() == 0 {
411        info!("No pending migrations");
412    }
413    for Migration { migration, .. } in migrations {
414        if let Some(steps) = steps.as_mut() {
415            if steps == &0 {
416                break;
417            }
418            *steps -= 1;
419        }
420        info!("Applying migration '{}'", migration.name());
421        migration.up(manager).await?;
422        info!("Migration '{}' has been applied", migration.name());
423        let now = SystemTime::now()
424            .duration_since(SystemTime::UNIX_EPOCH)
425            .expect("SystemTime before UNIX EPOCH!");
426        seaql_migrations::Entity::insert(seaql_migrations::ActiveModel {
427            version: ActiveValue::Set(migration.name().to_owned()),
428            applied_at: ActiveValue::Set(now.as_secs() as i64),
429        })
430        .table_name(M::migration_table_name())
431        .exec(db)
432        .await?;
433    }
434
435    Ok(())
436}
437
438async fn exec_down<M>(manager: &SchemaManager<'_>, mut steps: Option<u32>) -> Result<(), DbErr>
439where
440    M: MigratorTrait + ?Sized,
441{
442    let db = manager.get_connection();
443
444    M::install(db).await?;
445
446    if let Some(steps) = steps {
447        info!("Rolling back {} applied migrations", steps);
448    } else {
449        info!("Rolling back all applied migrations");
450    }
451
452    let migrations = M::get_applied_migrations(db).await?.into_iter().rev();
453    if migrations.len() == 0 {
454        info!("No applied migrations");
455    }
456    for Migration { migration, .. } in migrations {
457        if let Some(steps) = steps.as_mut() {
458            if steps == &0 {
459                break;
460            }
461            *steps -= 1;
462        }
463        info!("Rolling back migration '{}'", migration.name());
464        migration.down(manager).await?;
465        info!("Migration '{}' has been rollbacked", migration.name());
466        seaql_migrations::Entity::delete_many()
467            .filter(Expr::col(seaql_migrations::Column::Version).eq(migration.name()))
468            .table_name(M::migration_table_name())
469            .exec(db)
470            .await?;
471    }
472
473    Ok(())
474}
475
476fn query_tables<C>(db: &C) -> Result<SelectStatement, DbErr>
477where
478    C: ConnectionTrait,
479{
480    match db.get_database_backend() {
481        #[cfg(feature = "sqlx-mysql")]
482        DbBackend::MySql => Ok(sea_schema::mysql::MySql.query_tables()),
483        #[cfg(feature = "sqlx-postgres")]
484        DbBackend::Postgres => Ok(sea_schema::postgres::Postgres.query_tables()),
485        #[cfg(feature = "sqlx-sqlite")]
486        DbBackend::Sqlite => Ok(sea_schema::sqlite::Sqlite.query_tables()),
487        #[allow(unreachable_patterns)]
488        other => Err(DbErr::BackendNotSupported {
489            db: other.as_str(),
490            ctx: "query_tables",
491        }),
492    }
493}
494
495// this function is only called after checking db backend, the panic is unreachable
496fn get_current_schema<C>(db: &C) -> SimpleExpr
497where
498    C: ConnectionTrait,
499{
500    match db.get_database_backend() {
501        #[cfg(feature = "sqlx-mysql")]
502        DbBackend::MySql => sea_schema::mysql::MySql::get_current_schema(),
503        #[cfg(feature = "sqlx-postgres")]
504        DbBackend::Postgres => sea_schema::postgres::Postgres::get_current_schema(),
505        #[cfg(feature = "sqlx-sqlite")]
506        DbBackend::Sqlite => sea_schema::sqlite::Sqlite::get_current_schema(),
507        #[allow(unreachable_patterns)]
508        other => panic!("{other:?} feature is off"),
509    }
510}
511
// Identifiers used by `query_mysql_foreign_keys` to query the MySQL
// information schema for foreign key constraints.
#[derive(DeriveIden)]
enum InformationSchema {
    // Schema qualifier of the queried table.
    #[sea_orm(iden = "information_schema")]
    Schema,
    // Selected column; read back from result rows as "TABLE_NAME".
    #[sea_orm(iden = "TABLE_NAME")]
    TableName,
    // Selected column; read back from result rows as "CONSTRAINT_NAME".
    #[sea_orm(iden = "CONSTRAINT_NAME")]
    ConstraintName,
    // The variants below carry no explicit `iden`; presumably the derive
    // renders them in snake_case (e.g. `table_constraints`) — TODO confirm.
    // Table the query selects from.
    TableConstraints,
    // Column compared against the current schema.
    TableSchema,
    // Column compared against the literal "FOREIGN KEY".
    ConstraintType,
}
524
525fn query_mysql_foreign_keys<C>(db: &C) -> SelectStatement
526where
527    C: ConnectionTrait,
528{
529    let mut stmt = Query::select();
530    stmt.columns([
531        InformationSchema::TableName,
532        InformationSchema::ConstraintName,
533    ])
534    .from((
535        InformationSchema::Schema,
536        InformationSchema::TableConstraints,
537    ))
538    .cond_where(
539        Condition::all()
540            .add(get_current_schema(db).equals((
541                InformationSchema::TableConstraints,
542                InformationSchema::TableSchema,
543            )))
544            .add(
545                Expr::col((
546                    InformationSchema::TableConstraints,
547                    InformationSchema::ConstraintType,
548                ))
549                .eq("FOREIGN KEY"),
550            ),
551    );
552    stmt
553}
554
// Identifiers used by `query_pg_types` for the type catalog side of the query.
// No variant carries an explicit `iden`; presumably `Table` renders as the
// snake_case enum name (`pg_type`) and the others as snake_case variant
// names — TODO confirm against the `DeriveIden` documentation.
#[derive(DeriveIden)]
enum PgType {
    // Table the query selects from.
    Table,
    // Selected column; read back from result rows as "typname".
    Typname,
    // Column joined against `PgNamespace::Oid`.
    Typnamespace,
    // Column filtered to equal 0.
    Typelem,
}
562
// Identifiers used by `query_pg_types` for the namespace side of the join.
// No variant carries an explicit `iden`; presumably `Table` renders as the
// snake_case enum name (`pg_namespace`) — TODO confirm.
#[derive(DeriveIden)]
enum PgNamespace {
    // Table joined to `PgType::Table`.
    Table,
    // Column joined against `PgType::Typnamespace`.
    Oid,
    // Column compared against the current schema.
    Nspname,
}
569
570fn query_pg_types<C>(db: &C) -> SelectStatement
571where
572    C: ConnectionTrait,
573{
574    let mut stmt = Query::select();
575    stmt.column(PgType::Typname)
576        .from(PgType::Table)
577        .join(
578            JoinType::LeftJoin,
579            PgNamespace::Table,
580            Expr::col((PgNamespace::Table, PgNamespace::Oid))
581                .equals((PgType::Table, PgType::Typnamespace)),
582        )
583        .cond_where(
584            Condition::all()
585                .add(get_current_schema(db).equals((PgNamespace::Table, PgNamespace::Nspname)))
586                .add(Expr::col((PgType::Table, PgType::Typelem)).eq(0)),
587        );
588    stmt
589}
590
/// Sets the table a statement operates on.
///
/// In this file it is used to point statements built from the
/// `seaql_migrations` entity at `MigratorTrait::migration_table_name()`.
trait QueryTable {
    /// Statement type handed back by [`QueryTable::table_name`].
    type Statement;

    /// Consume the statement and return it with `table_name` as its table.
    fn table_name(self, table_name: DynIden) -> Self::Statement;
}
596
597impl QueryTable for SelectStatement {
598    type Statement = SelectStatement;
599
600    fn table_name(mut self, table_name: DynIden) -> SelectStatement {
601        self.from(table_name);
602        self
603    }
604}
605
606impl QueryTable for sea_query::TableCreateStatement {
607    type Statement = sea_query::TableCreateStatement;
608
609    fn table_name(mut self, table_name: DynIden) -> sea_query::TableCreateStatement {
610        self.table(table_name);
611        self
612    }
613}
614
615impl<A> QueryTable for sea_orm::Insert<A>
616where
617    A: ActiveModelTrait,
618{
619    type Statement = sea_orm::Insert<A>;
620
621    fn table_name(mut self, table_name: DynIden) -> sea_orm::Insert<A> {
622        sea_orm::QueryTrait::query(&mut self).into_table(table_name);
623        self
624    }
625}
626
627impl<E> QueryTable for sea_orm::DeleteMany<E>
628where
629    E: EntityTrait,
630{
631    type Statement = sea_orm::DeleteMany<E>;
632
633    fn table_name(mut self, table_name: DynIden) -> sea_orm::DeleteMany<E> {
634        sea_orm::QueryTrait::query(&mut self).from_table(table_name);
635        self
636    }
637}