sea_orm_migration/
migrator.rs

1use std::collections::HashSet;
2use std::fmt::Display;
3use std::future::Future;
4use std::pin::Pin;
5use std::time::SystemTime;
6use tracing::info;
7
8use sea_orm::sea_query::{
9    self, extension::postgres::Type, Alias, Expr, ForeignKey, IntoIden, JoinType, Order, Query,
10    SelectStatement, SimpleExpr, Table,
11};
12use sea_orm::{
13    ActiveModelTrait, ActiveValue, Condition, ConnectionTrait, DbBackend, DbErr, DeriveIden,
14    DynIden, EntityTrait, FromQueryResult, Iterable, QueryFilter, Schema, Statement,
15    TransactionTrait,
16};
17use sea_schema::{mysql::MySql, postgres::Postgres, probe::SchemaProbe, sqlite::Sqlite};
18
19use super::{seaql_migrations, IntoSchemaManagerConnection, MigrationTrait, SchemaManager};
20
/// Status of migration
///
/// Tells whether a migration known to the migrator has already been
/// recorded in the migration table (`Applied`) or not yet (`Pending`).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum MigrationStatus {
    /// Not yet applied: no row for this migration in the migration table.
    Pending,
    /// Applied: the migration table holds a row for this migration.
    Applied,
}
29
30impl Display for MigrationStatus {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        let status = match self {
33            MigrationStatus::Pending => "Pending",
34            MigrationStatus::Applied => "Applied",
35        };
36        write!(f, "{status}")
37    }
38}
39
40pub struct Migration {
41    migration: Box<dyn MigrationTrait>,
42    status: MigrationStatus,
43}
44
45impl Migration {
46    /// Get migration name from MigrationName trait implementation
47    pub fn name(&self) -> &str {
48        self.migration.name()
49    }
50
51    /// Get migration status
52    pub fn status(&self) -> MigrationStatus {
53        self.status
54    }
55}
56
57/// Performing migrations on a database
58#[async_trait::async_trait]
59pub trait MigratorTrait: Send {
60    /// Vector of migrations in time sequence
61    fn migrations() -> Vec<Box<dyn MigrationTrait>>;
62
63    /// Name of the migration table, it is `seaql_migrations` by default
64    fn migration_table_name() -> DynIden {
65        seaql_migrations::Entity.into_iden()
66    }
67
68    /// Get list of migrations wrapped in `Migration` struct
69    fn get_migration_files() -> Vec<Migration> {
70        Self::migrations()
71            .into_iter()
72            .map(|migration| Migration {
73                migration,
74                status: MigrationStatus::Pending,
75            })
76            .collect()
77    }
78
79    /// Get list of applied migrations from database
80    async fn get_migration_models<C>(db: &C) -> Result<Vec<seaql_migrations::Model>, DbErr>
81    where
82        C: ConnectionTrait,
83    {
84        Self::install(db).await?;
85        let stmt = Query::select()
86            .table_name(Self::migration_table_name())
87            .columns(seaql_migrations::Column::iter().map(IntoIden::into_iden))
88            .order_by(seaql_migrations::Column::Version, Order::Asc)
89            .to_owned();
90        let builder = db.get_database_backend();
91        seaql_migrations::Model::find_by_statement(builder.build(&stmt))
92            .all(db)
93            .await
94    }
95
96    /// Get list of migrations with status
97    async fn get_migration_with_status<C>(db: &C) -> Result<Vec<Migration>, DbErr>
98    where
99        C: ConnectionTrait,
100    {
101        Self::install(db).await?;
102        let mut migration_files = Self::get_migration_files();
103        let migration_models = Self::get_migration_models(db).await?;
104
105        let migration_in_db: HashSet<String> = migration_models
106            .into_iter()
107            .map(|model| model.version)
108            .collect();
109        let migration_in_fs: HashSet<String> = migration_files
110            .iter()
111            .map(|file| file.migration.name().to_string())
112            .collect();
113
114        let pending_migrations = &migration_in_fs - &migration_in_db;
115        for migration_file in migration_files.iter_mut() {
116            if !pending_migrations.contains(migration_file.migration.name()) {
117                migration_file.status = MigrationStatus::Applied;
118            }
119        }
120
121        let missing_migrations_in_fs = &migration_in_db - &migration_in_fs;
122        let errors: Vec<String> = missing_migrations_in_fs
123            .iter()
124            .map(|missing_migration| {
125                format!("Migration file of version '{missing_migration}' is missing, this migration has been applied but its file is missing")
126            }).collect();
127
128        if !errors.is_empty() {
129            Err(DbErr::Custom(errors.join("\n")))
130        } else {
131            Ok(migration_files)
132        }
133    }
134
135    /// Get list of pending migrations
136    async fn get_pending_migrations<C>(db: &C) -> Result<Vec<Migration>, DbErr>
137    where
138        C: ConnectionTrait,
139    {
140        Self::install(db).await?;
141        Ok(Self::get_migration_with_status(db)
142            .await?
143            .into_iter()
144            .filter(|file| file.status == MigrationStatus::Pending)
145            .collect())
146    }
147
148    /// Get list of applied migrations
149    async fn get_applied_migrations<C>(db: &C) -> Result<Vec<Migration>, DbErr>
150    where
151        C: ConnectionTrait,
152    {
153        Self::install(db).await?;
154        Ok(Self::get_migration_with_status(db)
155            .await?
156            .into_iter()
157            .filter(|file| file.status == MigrationStatus::Applied)
158            .collect())
159    }
160
161    /// Create migration table `seaql_migrations` in the database
162    async fn install<C>(db: &C) -> Result<(), DbErr>
163    where
164        C: ConnectionTrait,
165    {
166        let builder = db.get_database_backend();
167        let table_name = Self::migration_table_name();
168        let schema = Schema::new(builder);
169        let mut stmt = schema
170            .create_table_from_entity(seaql_migrations::Entity)
171            .table_name(table_name);
172        stmt.if_not_exists();
173        db.execute(builder.build(&stmt)).await.map(|_| ())
174    }
175
176    /// Check the status of all migrations
177    async fn status<C>(db: &C) -> Result<(), DbErr>
178    where
179        C: ConnectionTrait,
180    {
181        Self::install(db).await?;
182
183        info!("Checking migration status");
184
185        for Migration { migration, status } in Self::get_migration_with_status(db).await? {
186            info!("Migration '{}'... {}", migration.name(), status);
187        }
188
189        Ok(())
190    }
191
192    /// Drop all tables from the database, then reapply all migrations
193    async fn fresh<'c, C>(db: C) -> Result<(), DbErr>
194    where
195        C: IntoSchemaManagerConnection<'c>,
196    {
197        exec_with_connection::<'_, _, _>(db, move |manager| {
198            Box::pin(async move { exec_fresh::<Self>(manager).await })
199        })
200        .await
201    }
202
203    /// Rollback all applied migrations, then reapply all migrations
204    async fn refresh<'c, C>(db: C) -> Result<(), DbErr>
205    where
206        C: IntoSchemaManagerConnection<'c>,
207    {
208        exec_with_connection::<'_, _, _>(db, move |manager| {
209            Box::pin(async move {
210                exec_down::<Self>(manager, None).await?;
211                exec_up::<Self>(manager, None).await
212            })
213        })
214        .await
215    }
216
217    /// Rollback all applied migrations
218    async fn reset<'c, C>(db: C) -> Result<(), DbErr>
219    where
220        C: IntoSchemaManagerConnection<'c>,
221    {
222        exec_with_connection::<'_, _, _>(db, move |manager| {
223            Box::pin(async move { exec_down::<Self>(manager, None).await })
224        })
225        .await
226    }
227
228    /// Apply pending migrations
229    async fn up<'c, C>(db: C, steps: Option<u32>) -> Result<(), DbErr>
230    where
231        C: IntoSchemaManagerConnection<'c>,
232    {
233        exec_with_connection::<'_, _, _>(db, move |manager| {
234            Box::pin(async move { exec_up::<Self>(manager, steps).await })
235        })
236        .await
237    }
238
239    /// Rollback applied migrations
240    async fn down<'c, C>(db: C, steps: Option<u32>) -> Result<(), DbErr>
241    where
242        C: IntoSchemaManagerConnection<'c>,
243    {
244        exec_with_connection::<'_, _, _>(db, move |manager| {
245            Box::pin(async move { exec_down::<Self>(manager, steps).await })
246        })
247        .await
248    }
249}
250
251async fn exec_with_connection<'c, C, F>(db: C, f: F) -> Result<(), DbErr>
252where
253    C: IntoSchemaManagerConnection<'c>,
254    F: for<'b> Fn(
255        &'b SchemaManager<'_>,
256    ) -> Pin<Box<dyn Future<Output = Result<(), DbErr>> + Send + 'b>>,
257{
258    let db = db.into_schema_manager_connection();
259
260    match db.get_database_backend() {
261        DbBackend::Postgres => {
262            let transaction = db.begin().await?;
263            let manager = SchemaManager::new(&transaction);
264            f(&manager).await?;
265            transaction.commit().await
266        }
267        DbBackend::MySql | DbBackend::Sqlite => {
268            let manager = SchemaManager::new(db);
269            f(&manager).await
270        }
271    }
272}
273
274async fn exec_fresh<M>(manager: &SchemaManager<'_>) -> Result<(), DbErr>
275where
276    M: MigratorTrait + ?Sized,
277{
278    let db = manager.get_connection();
279
280    M::install(db).await?;
281    let db_backend = db.get_database_backend();
282
283    // Temporarily disable the foreign key check
284    if db_backend == DbBackend::Sqlite {
285        info!("Disabling foreign key check");
286        db.execute(Statement::from_string(
287            db_backend,
288            "PRAGMA foreign_keys = OFF".to_owned(),
289        ))
290        .await?;
291        info!("Foreign key check disabled");
292    }
293
294    // Drop all foreign keys
295    if db_backend == DbBackend::MySql {
296        info!("Dropping all foreign keys");
297        let stmt = query_mysql_foreign_keys(db);
298        let rows = db.query_all(db_backend.build(&stmt)).await?;
299        for row in rows.into_iter() {
300            let constraint_name: String = row.try_get("", "CONSTRAINT_NAME")?;
301            let table_name: String = row.try_get("", "TABLE_NAME")?;
302            info!(
303                "Dropping foreign key '{}' from table '{}'",
304                constraint_name, table_name
305            );
306            let mut stmt = ForeignKey::drop();
307            stmt.table(Alias::new(table_name.as_str()))
308                .name(constraint_name.as_str());
309            db.execute(db_backend.build(&stmt)).await?;
310            info!("Foreign key '{}' has been dropped", constraint_name);
311        }
312        info!("All foreign keys dropped");
313    }
314
315    // Drop all tables
316    let stmt = query_tables(db).await;
317    let rows = db.query_all(db_backend.build(&stmt)).await?;
318    for row in rows.into_iter() {
319        let table_name: String = row.try_get("", "table_name")?;
320        info!("Dropping table '{}'", table_name);
321        let mut stmt = Table::drop();
322        stmt.table(Alias::new(table_name.as_str()))
323            .if_exists()
324            .cascade();
325        db.execute(db_backend.build(&stmt)).await?;
326        info!("Table '{}' has been dropped", table_name);
327    }
328
329    // Drop all types
330    if db_backend == DbBackend::Postgres {
331        info!("Dropping all types");
332        let stmt = query_pg_types(db);
333        let rows = db.query_all(db_backend.build(&stmt)).await?;
334        for row in rows {
335            let type_name: String = row.try_get("", "typname")?;
336            info!("Dropping type '{}'", type_name);
337            let mut stmt = Type::drop();
338            stmt.name(Alias::new(&type_name));
339            db.execute(db_backend.build(&stmt)).await?;
340            info!("Type '{}' has been dropped", type_name);
341        }
342    }
343
344    // Restore the foreign key check
345    if db_backend == DbBackend::Sqlite {
346        info!("Restoring foreign key check");
347        db.execute(Statement::from_string(
348            db_backend,
349            "PRAGMA foreign_keys = ON".to_owned(),
350        ))
351        .await?;
352        info!("Foreign key check restored");
353    }
354
355    // Reapply all migrations
356    exec_up::<M>(manager, None).await
357}
358
359async fn exec_up<M>(manager: &SchemaManager<'_>, mut steps: Option<u32>) -> Result<(), DbErr>
360where
361    M: MigratorTrait + ?Sized,
362{
363    let db = manager.get_connection();
364
365    M::install(db).await?;
366
367    if let Some(steps) = steps {
368        info!("Applying {} pending migrations", steps);
369    } else {
370        info!("Applying all pending migrations");
371    }
372
373    let migrations = M::get_pending_migrations(db).await?.into_iter();
374    if migrations.len() == 0 {
375        info!("No pending migrations");
376    }
377    for Migration { migration, .. } in migrations {
378        if let Some(steps) = steps.as_mut() {
379            if steps == &0 {
380                break;
381            }
382            *steps -= 1;
383        }
384        info!("Applying migration '{}'", migration.name());
385        migration.up(manager).await?;
386        info!("Migration '{}' has been applied", migration.name());
387        let now = SystemTime::now()
388            .duration_since(SystemTime::UNIX_EPOCH)
389            .expect("SystemTime before UNIX EPOCH!");
390        seaql_migrations::Entity::insert(seaql_migrations::ActiveModel {
391            version: ActiveValue::Set(migration.name().to_owned()),
392            applied_at: ActiveValue::Set(now.as_secs() as i64),
393        })
394        .table_name(M::migration_table_name())
395        .exec(db)
396        .await?;
397    }
398
399    Ok(())
400}
401
402async fn exec_down<M>(manager: &SchemaManager<'_>, mut steps: Option<u32>) -> Result<(), DbErr>
403where
404    M: MigratorTrait + ?Sized,
405{
406    let db = manager.get_connection();
407
408    M::install(db).await?;
409
410    if let Some(steps) = steps {
411        info!("Rolling back {} applied migrations", steps);
412    } else {
413        info!("Rolling back all applied migrations");
414    }
415
416    let migrations = M::get_applied_migrations(db).await?.into_iter().rev();
417    if migrations.len() == 0 {
418        info!("No applied migrations");
419    }
420    for Migration { migration, .. } in migrations {
421        if let Some(steps) = steps.as_mut() {
422            if steps == &0 {
423                break;
424            }
425            *steps -= 1;
426        }
427        info!("Rolling back migration '{}'", migration.name());
428        migration.down(manager).await?;
429        info!("Migration '{}' has been rollbacked", migration.name());
430        seaql_migrations::Entity::delete_many()
431            .filter(Expr::col(seaql_migrations::Column::Version).eq(migration.name()))
432            .table_name(M::migration_table_name())
433            .exec(db)
434            .await?;
435    }
436
437    Ok(())
438}
439
440async fn query_tables<C>(db: &C) -> SelectStatement
441where
442    C: ConnectionTrait,
443{
444    match db.get_database_backend() {
445        DbBackend::MySql => MySql.query_tables(),
446        DbBackend::Postgres => Postgres.query_tables(),
447        DbBackend::Sqlite => Sqlite.query_tables(),
448    }
449}
450
451fn get_current_schema<C>(db: &C) -> SimpleExpr
452where
453    C: ConnectionTrait,
454{
455    match db.get_database_backend() {
456        DbBackend::MySql => MySql::get_current_schema(),
457        DbBackend::Postgres => Postgres::get_current_schema(),
458        DbBackend::Sqlite => unimplemented!(),
459    }
460}
461
/// Identifiers for querying `information_schema.table_constraints`, used by
/// `query_mysql_foreign_keys` to list foreign key constraints.
///
/// Variants with an explicit `iden` attribute use that exact SQL name.
/// NOTE(review): the rest presumably get the snake_case form of the variant name
/// from `DeriveIden` (e.g. `table_constraints`) — confirm against the `DeriveIden` docs.
#[derive(DeriveIden)]
enum InformationSchema {
    /// The `information_schema` schema itself.
    #[sea_orm(iden = "information_schema")]
    Schema,
    /// Upper-case column name, matching the `TABLE_NAME` key read back from result rows.
    #[sea_orm(iden = "TABLE_NAME")]
    TableName,
    /// Upper-case column name, matching the `CONSTRAINT_NAME` key read back from result rows.
    #[sea_orm(iden = "CONSTRAINT_NAME")]
    ConstraintName,
    /// The constraints table queried under `Schema`.
    TableConstraints,
    /// Compared against the connection's current schema.
    TableSchema,
    /// Filtered to `"FOREIGN KEY"`.
    ConstraintType,
}
474
475fn query_mysql_foreign_keys<C>(db: &C) -> SelectStatement
476where
477    C: ConnectionTrait,
478{
479    let mut stmt = Query::select();
480    stmt.columns([
481        InformationSchema::TableName,
482        InformationSchema::ConstraintName,
483    ])
484    .from((
485        InformationSchema::Schema,
486        InformationSchema::TableConstraints,
487    ))
488    .cond_where(
489        Condition::all()
490            .add(Expr::expr(get_current_schema(db)).equals((
491                InformationSchema::TableConstraints,
492                InformationSchema::TableSchema,
493            )))
494            .add(
495                Expr::col((
496                    InformationSchema::TableConstraints,
497                    InformationSchema::ConstraintType,
498                ))
499                .eq("FOREIGN KEY"),
500            ),
501    );
502    stmt
503}
504
/// Identifiers for the Postgres system catalog queried by `query_pg_types`.
///
/// NOTE(review): `DeriveIden` presumably renders `Table` as the snake_case enum
/// name (`pg_type`) and the other variants in snake_case (`typname`, ...) —
/// confirm against the `DeriveIden` docs.
#[derive(DeriveIden)]
enum PgType {
    Table,
    /// The type name column, read back as `typname` when dropping types.
    Typname,
    /// Joined against the namespace table's `Oid` to find the type's schema.
    Typnamespace,
    /// Filtered to `0` in `query_pg_types`.
    Typelem,
}
512
/// Identifiers for the Postgres namespace catalog joined by `query_pg_types`.
///
/// NOTE(review): `DeriveIden` presumably renders `Table` as `pg_namespace` —
/// confirm against the `DeriveIden` docs.
#[derive(DeriveIden)]
enum PgNamespace {
    Table,
    /// Join key matched against `PgType::Typnamespace`.
    Oid,
    /// Compared against the connection's current schema.
    Nspname,
}
519
520fn query_pg_types<C>(db: &C) -> SelectStatement
521where
522    C: ConnectionTrait,
523{
524    let mut stmt = Query::select();
525    stmt.column(PgType::Typname)
526        .from(PgType::Table)
527        .join(
528            JoinType::LeftJoin,
529            PgNamespace::Table,
530            Expr::col((PgNamespace::Table, PgNamespace::Oid))
531                .equals((PgType::Table, PgType::Typnamespace)),
532        )
533        .cond_where(
534            Condition::all()
535                .add(
536                    Expr::expr(get_current_schema(db))
537                        .equals((PgNamespace::Table, PgNamespace::Nspname)),
538                )
539                .add(Expr::col((PgType::Table, PgType::Typelem)).eq(0)),
540        );
541    stmt
542}
543
544trait QueryTable {
545    type Statement;
546
547    fn table_name(self, table_name: DynIden) -> Self::Statement;
548}
549
550impl QueryTable for SelectStatement {
551    type Statement = SelectStatement;
552
553    fn table_name(mut self, table_name: DynIden) -> SelectStatement {
554        self.from(table_name);
555        self
556    }
557}
558
559impl QueryTable for sea_query::TableCreateStatement {
560    type Statement = sea_query::TableCreateStatement;
561
562    fn table_name(mut self, table_name: DynIden) -> sea_query::TableCreateStatement {
563        self.table(table_name);
564        self
565    }
566}
567
568impl<A> QueryTable for sea_orm::Insert<A>
569where
570    A: ActiveModelTrait,
571{
572    type Statement = sea_orm::Insert<A>;
573
574    fn table_name(mut self, table_name: DynIden) -> sea_orm::Insert<A> {
575        sea_orm::QueryTrait::query(&mut self).into_table(table_name);
576        self
577    }
578}
579
580impl<E> QueryTable for sea_orm::DeleteMany<E>
581where
582    E: EntityTrait,
583{
584    type Statement = sea_orm::DeleteMany<E>;
585
586    fn table_name(mut self, table_name: DynIden) -> sea_orm::DeleteMany<E> {
587        sea_orm::QueryTrait::query(&mut self).from_table(table_name);
588        self
589    }
590}