1use std::collections::HashSet;
2use std::fmt::Display;
3use std::future::Future;
4use std::pin::Pin;
5use std::time::SystemTime;
6use tracing::info;
7
8use sea_orm::sea_query::{
9 self, Alias, Expr, ExprTrait, ForeignKey, IntoIden, JoinType, Order, Query, SelectStatement,
10 SimpleExpr, Table, extension::postgres::Type,
11};
12use sea_orm::{
13 ActiveModelTrait, ActiveValue, Condition, ConnectionTrait, DbBackend, DbErr, DeriveIden,
14 DynIden, EntityTrait, FromQueryResult, Iterable, QueryFilter, Schema, Statement,
15 TransactionTrait,
16};
17#[allow(unused_imports)]
18use sea_schema::probe::SchemaProbe;
19
20use super::{IntoSchemaManagerConnection, MigrationTrait, SchemaManager, seaql_migrations};
21
/// Whether a migration has been recorded as applied in the migration table.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum MigrationStatus {
    /// No row for this migration exists in the migration table yet.
    Pending,
    /// The migration's name is recorded in the migration table.
    Applied,
}

impl Display for MigrationStatus {
    /// Writes the variant name (`Pending` / `Applied`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::Pending => "Pending",
            Self::Applied => "Applied",
        };
        f.write_str(label)
    }
}
40
41pub struct Migration {
42 migration: Box<dyn MigrationTrait>,
43 status: MigrationStatus,
44}
45
46impl Migration {
47 pub fn name(&self) -> &str {
49 self.migration.name()
50 }
51
52 pub fn status(&self) -> MigrationStatus {
54 self.status
55 }
56}
57
58#[async_trait::async_trait]
60pub trait MigratorTrait: Send {
61 fn migrations() -> Vec<Box<dyn MigrationTrait>>;
63
64 fn migration_table_name() -> DynIden {
66 seaql_migrations::Entity.into_iden()
67 }
68
69 fn get_migration_files() -> Vec<Migration> {
71 Self::migrations()
72 .into_iter()
73 .map(|migration| Migration {
74 migration,
75 status: MigrationStatus::Pending,
76 })
77 .collect()
78 }
79
80 async fn get_migration_models<C>(db: &C) -> Result<Vec<seaql_migrations::Model>, DbErr>
82 where
83 C: ConnectionTrait,
84 {
85 Self::install(db).await?;
86 let stmt = Query::select()
87 .table_name(Self::migration_table_name())
88 .columns(seaql_migrations::Column::iter().map(IntoIden::into_iden))
89 .order_by(seaql_migrations::Column::Version, Order::Asc)
90 .to_owned();
91 let builder = db.get_database_backend();
92 seaql_migrations::Model::find_by_statement(builder.build(&stmt))
93 .all(db)
94 .await
95 }
96
97 async fn get_migration_with_status<C>(db: &C) -> Result<Vec<Migration>, DbErr>
99 where
100 C: ConnectionTrait,
101 {
102 Self::install(db).await?;
103 let mut migration_files = Self::get_migration_files();
104 let migration_models = Self::get_migration_models(db).await?;
105
106 let migration_in_db: HashSet<String> = migration_models
107 .into_iter()
108 .map(|model| model.version)
109 .collect();
110 let migration_in_fs: HashSet<String> = migration_files
111 .iter()
112 .map(|file| file.migration.name().to_string())
113 .collect();
114
115 let pending_migrations = &migration_in_fs - &migration_in_db;
116 for migration_file in migration_files.iter_mut() {
117 if !pending_migrations.contains(migration_file.migration.name()) {
118 migration_file.status = MigrationStatus::Applied;
119 }
120 }
121
122 let missing_migrations_in_fs = &migration_in_db - &migration_in_fs;
123 let errors: Vec<String> = missing_migrations_in_fs
124 .iter()
125 .map(|missing_migration| {
126 format!("Migration file of version '{missing_migration}' is missing, this migration has been applied but its file is missing")
127 }).collect();
128
129 if !errors.is_empty() {
130 Err(DbErr::Custom(errors.join("\n")))
131 } else {
132 Ok(migration_files)
133 }
134 }
135
136 async fn get_pending_migrations<C>(db: &C) -> Result<Vec<Migration>, DbErr>
138 where
139 C: ConnectionTrait,
140 {
141 Self::install(db).await?;
142 Ok(Self::get_migration_with_status(db)
143 .await?
144 .into_iter()
145 .filter(|file| file.status == MigrationStatus::Pending)
146 .collect())
147 }
148
149 async fn get_applied_migrations<C>(db: &C) -> Result<Vec<Migration>, DbErr>
151 where
152 C: ConnectionTrait,
153 {
154 Self::install(db).await?;
155 Ok(Self::get_migration_with_status(db)
156 .await?
157 .into_iter()
158 .filter(|file| file.status == MigrationStatus::Applied)
159 .collect())
160 }
161
162 async fn install<C>(db: &C) -> Result<(), DbErr>
164 where
165 C: ConnectionTrait,
166 {
167 let builder = db.get_database_backend();
168 let table_name = Self::migration_table_name();
169 let schema = Schema::new(builder);
170 let mut stmt = schema
171 .create_table_from_entity(seaql_migrations::Entity)
172 .table_name(table_name);
173 stmt.if_not_exists();
174 db.execute(&stmt).await.map(|_| ())
175 }
176
177 async fn status<C>(db: &C) -> Result<(), DbErr>
179 where
180 C: ConnectionTrait,
181 {
182 Self::install(db).await?;
183
184 info!("Checking migration status");
185
186 for Migration { migration, status } in Self::get_migration_with_status(db).await? {
187 info!("Migration '{}'... {}", migration.name(), status);
188 }
189
190 Ok(())
191 }
192
193 async fn fresh<'c, C>(db: C) -> Result<(), DbErr>
195 where
196 C: IntoSchemaManagerConnection<'c>,
197 {
198 exec_with_connection::<'_, _, _>(db, move |manager| {
199 Box::pin(async move { exec_fresh::<Self>(manager).await })
200 })
201 .await
202 }
203
204 async fn refresh<'c, C>(db: C) -> Result<(), DbErr>
206 where
207 C: IntoSchemaManagerConnection<'c>,
208 {
209 exec_with_connection::<'_, _, _>(db, move |manager| {
210 Box::pin(async move {
211 exec_down::<Self>(manager, None).await?;
212 exec_up::<Self>(manager, None).await
213 })
214 })
215 .await
216 }
217
218 async fn reset<'c, C>(db: C) -> Result<(), DbErr>
220 where
221 C: IntoSchemaManagerConnection<'c>,
222 {
223 exec_with_connection::<'_, _, _>(db, move |manager| {
224 Box::pin(async move {
225 exec_down::<Self>(manager, None).await?;
227
228 let mut stmt = Table::drop();
230 stmt.table(Self::migration_table_name())
231 .if_exists()
232 .cascade();
233 manager.drop_table(stmt).await?;
234
235 Ok(())
236 })
237 })
238 .await
239 }
240
241 async fn uninstall<'c, C>(db: C) -> Result<(), DbErr>
244 where
245 C: IntoSchemaManagerConnection<'c>,
246 {
247 exec_with_connection::<'_, _, _>(db, move |manager| {
248 Box::pin(async move {
249 let mut stmt = Table::drop();
250 stmt.table(Self::migration_table_name())
251 .if_exists()
252 .cascade();
253 manager.drop_table(stmt).await?;
254 Ok(())
255 })
256 })
257 .await
258 }
259
260 async fn up<'c, C>(db: C, steps: Option<u32>) -> Result<(), DbErr>
262 where
263 C: IntoSchemaManagerConnection<'c>,
264 {
265 exec_with_connection::<'_, _, _>(db, move |manager| {
266 Box::pin(async move { exec_up::<Self>(manager, steps).await })
267 })
268 .await
269 }
270
271 async fn down<'c, C>(db: C, steps: Option<u32>) -> Result<(), DbErr>
273 where
274 C: IntoSchemaManagerConnection<'c>,
275 {
276 exec_with_connection::<'_, _, _>(db, move |manager| {
277 Box::pin(async move { exec_down::<Self>(manager, steps).await })
278 })
279 .await
280 }
281}
282
283async fn exec_with_connection<'c, C, F>(db: C, f: F) -> Result<(), DbErr>
284where
285 C: IntoSchemaManagerConnection<'c>,
286 F: for<'b> Fn(
287 &'b SchemaManager<'_>,
288 ) -> Pin<Box<dyn Future<Output = Result<(), DbErr>> + Send + 'b>>,
289{
290 let db = db.into_schema_manager_connection();
291
292 match db.get_database_backend() {
293 DbBackend::Postgres => {
294 let transaction = db.begin().await?;
295 let manager = SchemaManager::new(&transaction);
296 f(&manager).await?;
297 transaction.commit().await
298 }
299 DbBackend::MySql | DbBackend::Sqlite => {
300 let manager = SchemaManager::new(db);
301 f(&manager).await
302 }
303 db => Err(DbErr::BackendNotSupported {
304 db: db.as_str(),
305 ctx: "exec_with_connection",
306 }),
307 }
308}
309
310async fn exec_fresh<M>(manager: &SchemaManager<'_>) -> Result<(), DbErr>
311where
312 M: MigratorTrait + ?Sized,
313{
314 let db = manager.get_connection();
315
316 M::install(db).await?;
317 let db_backend = db.get_database_backend();
318
319 if db_backend == DbBackend::Sqlite {
321 info!("Disabling foreign key check");
322 db.execute_raw(Statement::from_string(
323 db_backend,
324 "PRAGMA foreign_keys = OFF".to_owned(),
325 ))
326 .await?;
327 info!("Foreign key check disabled");
328 }
329
330 if db_backend == DbBackend::MySql {
332 info!("Dropping all foreign keys");
333 let stmt = query_mysql_foreign_keys(db);
334 let rows = db.query_all(&stmt).await?;
335 for row in rows.into_iter() {
336 let constraint_name: String = row.try_get("", "CONSTRAINT_NAME")?;
337 let table_name: String = row.try_get("", "TABLE_NAME")?;
338 info!(
339 "Dropping foreign key '{}' from table '{}'",
340 constraint_name, table_name
341 );
342 let mut stmt = ForeignKey::drop();
343 stmt.table(Alias::new(table_name.as_str()))
344 .name(constraint_name.as_str());
345 db.execute(&stmt).await?;
346 info!("Foreign key '{}' has been dropped", constraint_name);
347 }
348 info!("All foreign keys dropped");
349 }
350
351 let stmt = query_tables(db)?;
353 let rows = db.query_all(&stmt).await?;
354 for row in rows.into_iter() {
355 let table_name: String = row.try_get("", "table_name")?;
356 info!("Dropping table '{}'", table_name);
357 let mut stmt = Table::drop();
358 stmt.table(Alias::new(table_name.as_str()))
359 .if_exists()
360 .cascade();
361 db.execute(&stmt).await?;
362 info!("Table '{}' has been dropped", table_name);
363 }
364
365 if db_backend == DbBackend::Postgres {
367 info!("Dropping all types");
368 let stmt = query_pg_types(db);
369 let rows = db.query_all(&stmt).await?;
370 for row in rows {
371 let type_name: String = row.try_get("", "typname")?;
372 info!("Dropping type '{}'", type_name);
373 let mut stmt = Type::drop();
374 stmt.name(Alias::new(&type_name));
375 db.execute(&stmt).await?;
376 info!("Type '{}' has been dropped", type_name);
377 }
378 }
379
380 if db_backend == DbBackend::Sqlite {
382 info!("Restoring foreign key check");
383 db.execute_raw(Statement::from_string(
384 db_backend,
385 "PRAGMA foreign_keys = ON".to_owned(),
386 ))
387 .await?;
388 info!("Foreign key check restored");
389 }
390
391 exec_up::<M>(manager, None).await
393}
394
395async fn exec_up<M>(manager: &SchemaManager<'_>, mut steps: Option<u32>) -> Result<(), DbErr>
396where
397 M: MigratorTrait + ?Sized,
398{
399 let db = manager.get_connection();
400
401 M::install(db).await?;
402
403 if let Some(steps) = steps {
404 info!("Applying {} pending migrations", steps);
405 } else {
406 info!("Applying all pending migrations");
407 }
408
409 let migrations = M::get_pending_migrations(db).await?.into_iter();
410 if migrations.len() == 0 {
411 info!("No pending migrations");
412 }
413 for Migration { migration, .. } in migrations {
414 if let Some(steps) = steps.as_mut() {
415 if steps == &0 {
416 break;
417 }
418 *steps -= 1;
419 }
420 info!("Applying migration '{}'", migration.name());
421 migration.up(manager).await?;
422 info!("Migration '{}' has been applied", migration.name());
423 let now = SystemTime::now()
424 .duration_since(SystemTime::UNIX_EPOCH)
425 .expect("SystemTime before UNIX EPOCH!");
426 seaql_migrations::Entity::insert(seaql_migrations::ActiveModel {
427 version: ActiveValue::Set(migration.name().to_owned()),
428 applied_at: ActiveValue::Set(now.as_secs() as i64),
429 })
430 .table_name(M::migration_table_name())
431 .exec(db)
432 .await?;
433 }
434
435 Ok(())
436}
437
438async fn exec_down<M>(manager: &SchemaManager<'_>, mut steps: Option<u32>) -> Result<(), DbErr>
439where
440 M: MigratorTrait + ?Sized,
441{
442 let db = manager.get_connection();
443
444 M::install(db).await?;
445
446 if let Some(steps) = steps {
447 info!("Rolling back {} applied migrations", steps);
448 } else {
449 info!("Rolling back all applied migrations");
450 }
451
452 let migrations = M::get_applied_migrations(db).await?.into_iter().rev();
453 if migrations.len() == 0 {
454 info!("No applied migrations");
455 }
456 for Migration { migration, .. } in migrations {
457 if let Some(steps) = steps.as_mut() {
458 if steps == &0 {
459 break;
460 }
461 *steps -= 1;
462 }
463 info!("Rolling back migration '{}'", migration.name());
464 migration.down(manager).await?;
465 info!("Migration '{}' has been rollbacked", migration.name());
466 seaql_migrations::Entity::delete_many()
467 .filter(Expr::col(seaql_migrations::Column::Version).eq(migration.name()))
468 .table_name(M::migration_table_name())
469 .exec(db)
470 .await?;
471 }
472
473 Ok(())
474}
475
476fn query_tables<C>(db: &C) -> Result<SelectStatement, DbErr>
477where
478 C: ConnectionTrait,
479{
480 match db.get_database_backend() {
481 #[cfg(feature = "sqlx-mysql")]
482 DbBackend::MySql => Ok(sea_schema::mysql::MySql.query_tables()),
483 #[cfg(feature = "sqlx-postgres")]
484 DbBackend::Postgres => Ok(sea_schema::postgres::Postgres.query_tables()),
485 #[cfg(feature = "sqlx-sqlite")]
486 DbBackend::Sqlite => Ok(sea_schema::sqlite::Sqlite.query_tables()),
487 #[allow(unreachable_patterns)]
488 other => Err(DbErr::BackendNotSupported {
489 db: other.as_str(),
490 ctx: "query_tables",
491 }),
492 }
493}
494
495fn get_current_schema<C>(db: &C) -> SimpleExpr
497where
498 C: ConnectionTrait,
499{
500 match db.get_database_backend() {
501 #[cfg(feature = "sqlx-mysql")]
502 DbBackend::MySql => sea_schema::mysql::MySql::get_current_schema(),
503 #[cfg(feature = "sqlx-postgres")]
504 DbBackend::Postgres => sea_schema::postgres::Postgres::get_current_schema(),
505 #[cfg(feature = "sqlx-sqlite")]
506 DbBackend::Sqlite => sea_schema::sqlite::Sqlite::get_current_schema(),
507 #[allow(unreachable_patterns)]
508 other => panic!("{other:?} feature is off"),
509 }
510}
511
// Identifiers for `information_schema.table_constraints`, used by
// `query_mysql_foreign_keys`. Variant names/attributes are the rendered SQL
// identifiers, so renaming a variant changes the generated query.
#[derive(DeriveIden)]
enum InformationSchema {
    // The schema qualifier itself.
    #[sea_orm(iden = "information_schema")]
    Schema,
    // Upper-case so the result columns match the names `exec_fresh` reads
    // with `row.try_get("", "TABLE_NAME")` / `"CONSTRAINT_NAME"`.
    #[sea_orm(iden = "TABLE_NAME")]
    TableName,
    #[sea_orm(iden = "CONSTRAINT_NAME")]
    ConstraintName,
    // NOTE(review): the variants below presumably render via DeriveIden's
    // default snake_case naming (`table_constraints`, ...) — confirm.
    TableConstraints,
    TableSchema,
    ConstraintType,
}
524
525fn query_mysql_foreign_keys<C>(db: &C) -> SelectStatement
526where
527 C: ConnectionTrait,
528{
529 let mut stmt = Query::select();
530 stmt.columns([
531 InformationSchema::TableName,
532 InformationSchema::ConstraintName,
533 ])
534 .from((
535 InformationSchema::Schema,
536 InformationSchema::TableConstraints,
537 ))
538 .cond_where(
539 Condition::all()
540 .add(get_current_schema(db).equals((
541 InformationSchema::TableConstraints,
542 InformationSchema::TableSchema,
543 )))
544 .add(
545 Expr::col((
546 InformationSchema::TableConstraints,
547 InformationSchema::ConstraintType,
548 ))
549 .eq("FOREIGN KEY"),
550 ),
551 );
552 stmt
553}
554
// Identifiers for the Postgres `pg_type` catalog, used by `query_pg_types`.
// NOTE(review): `Table` presumably renders as the enum name in snake_case
// (`pg_type`) per DeriveIden convention; the other variants as snake_case
// column names — confirm.
#[derive(DeriveIden)]
enum PgType {
    Table,
    Typname,
    Typnamespace,
    Typelem,
}
562
// Identifiers for the Postgres `pg_namespace` catalog, joined in `query_pg_types`.
// NOTE(review): `Table` presumably renders as `pg_namespace` per DeriveIden
// convention — confirm.
#[derive(DeriveIden)]
enum PgNamespace {
    Table,
    Oid,
    Nspname,
}
569
570fn query_pg_types<C>(db: &C) -> SelectStatement
571where
572 C: ConnectionTrait,
573{
574 let mut stmt = Query::select();
575 stmt.column(PgType::Typname)
576 .from(PgType::Table)
577 .join(
578 JoinType::LeftJoin,
579 PgNamespace::Table,
580 Expr::col((PgNamespace::Table, PgNamespace::Oid))
581 .equals((PgType::Table, PgType::Typnamespace)),
582 )
583 .cond_where(
584 Condition::all()
585 .add(get_current_schema(db).equals((PgNamespace::Table, PgNamespace::Nspname)))
586 .add(Expr::col((PgType::Table, PgType::Typelem)).eq(0)),
587 );
588 stmt
589}
590
591trait QueryTable {
592 type Statement;
593
594 fn table_name(self, table_name: DynIden) -> Self::Statement;
595}
596
597impl QueryTable for SelectStatement {
598 type Statement = SelectStatement;
599
600 fn table_name(mut self, table_name: DynIden) -> SelectStatement {
601 self.from(table_name);
602 self
603 }
604}
605
606impl QueryTable for sea_query::TableCreateStatement {
607 type Statement = sea_query::TableCreateStatement;
608
609 fn table_name(mut self, table_name: DynIden) -> sea_query::TableCreateStatement {
610 self.table(table_name);
611 self
612 }
613}
614
615impl<A> QueryTable for sea_orm::Insert<A>
616where
617 A: ActiveModelTrait,
618{
619 type Statement = sea_orm::Insert<A>;
620
621 fn table_name(mut self, table_name: DynIden) -> sea_orm::Insert<A> {
622 sea_orm::QueryTrait::query(&mut self).into_table(table_name);
623 self
624 }
625}
626
627impl<E> QueryTable for sea_orm::DeleteMany<E>
628where
629 E: EntityTrait,
630{
631 type Statement = sea_orm::DeleteMany<E>;
632
633 fn table_name(mut self, table_name: DynIden) -> sea_orm::DeleteMany<E> {
634 sea_orm::QueryTrait::query(&mut self).from_table(table_name);
635 self
636 }
637}