1use std::collections::HashSet;
2use std::fmt::Display;
3use std::future::Future;
4use std::pin::Pin;
5use std::time::SystemTime;
6use tracing::info;
7
8use sea_orm::sea_query::{
9 self, Alias, Expr, ExprTrait, ForeignKey, IntoIden, JoinType, Order, Query, SelectStatement,
10 SimpleExpr, Table, extension::postgres::Type,
11};
12use sea_orm::{
13 ActiveModelTrait, ActiveValue, Condition, ConnectionTrait, DbBackend, DbErr, DeriveIden,
14 DynIden, EntityTrait, FromQueryResult, Iterable, QueryFilter, Schema, Statement,
15 TransactionTrait,
16};
17#[allow(unused_imports)]
18use sea_schema::probe::SchemaProbe;
19
20use super::{IntoSchemaManagerConnection, MigrationTrait, SchemaManager, seaql_migrations};
21
/// Whether a migration has been recorded as applied in the database.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MigrationStatus {
    /// The migration has not been applied yet.
    Pending,
    /// The migration has already been applied.
    Applied,
}
30
31impl Display for MigrationStatus {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 let status = match self {
34 MigrationStatus::Pending => "Pending",
35 MigrationStatus::Applied => "Applied",
36 };
37 write!(f, "{status}")
38 }
39}
40
41pub struct Migration {
42 migration: Box<dyn MigrationTrait>,
43 status: MigrationStatus,
44}
45
46impl Migration {
47 pub fn name(&self) -> &str {
49 self.migration.name()
50 }
51
52 pub fn status(&self) -> MigrationStatus {
54 self.status
55 }
56}
57
58#[async_trait::async_trait]
60pub trait MigratorTrait: Send {
61 fn migrations() -> Vec<Box<dyn MigrationTrait>>;
63
64 fn migration_table_name() -> DynIden {
66 seaql_migrations::Entity.into_iden()
67 }
68
69 fn get_migration_files() -> Vec<Migration> {
71 Self::migrations()
72 .into_iter()
73 .map(|migration| Migration {
74 migration,
75 status: MigrationStatus::Pending,
76 })
77 .collect()
78 }
79
80 async fn get_migration_models<C>(db: &C) -> Result<Vec<seaql_migrations::Model>, DbErr>
82 where
83 C: ConnectionTrait,
84 {
85 Self::install(db).await?;
86 let stmt = Query::select()
87 .table_name(Self::migration_table_name())
88 .columns(seaql_migrations::Column::iter().map(IntoIden::into_iden))
89 .order_by(seaql_migrations::Column::Version, Order::Asc)
90 .to_owned();
91 let builder = db.get_database_backend();
92 seaql_migrations::Model::find_by_statement(builder.build(&stmt))
93 .all(db)
94 .await
95 }
96
97 async fn get_migration_with_status<C>(db: &C) -> Result<Vec<Migration>, DbErr>
99 where
100 C: ConnectionTrait,
101 {
102 Self::install(db).await?;
103 let mut migration_files = Self::get_migration_files();
104 let migration_models = Self::get_migration_models(db).await?;
105
106 let migration_in_db: HashSet<String> = migration_models
107 .into_iter()
108 .map(|model| model.version)
109 .collect();
110 let migration_in_fs: HashSet<String> = migration_files
111 .iter()
112 .map(|file| file.migration.name().to_string())
113 .collect();
114
115 let pending_migrations = &migration_in_fs - &migration_in_db;
116 for migration_file in migration_files.iter_mut() {
117 if !pending_migrations.contains(migration_file.migration.name()) {
118 migration_file.status = MigrationStatus::Applied;
119 }
120 }
121
122 let missing_migrations_in_fs = &migration_in_db - &migration_in_fs;
123 let errors: Vec<String> = missing_migrations_in_fs
124 .iter()
125 .map(|missing_migration| {
126 format!("Migration file of version '{missing_migration}' is missing, this migration has been applied but its file is missing")
127 }).collect();
128
129 if !errors.is_empty() {
130 Err(DbErr::Custom(errors.join("\n")))
131 } else {
132 Ok(migration_files)
133 }
134 }
135
136 async fn get_pending_migrations<C>(db: &C) -> Result<Vec<Migration>, DbErr>
138 where
139 C: ConnectionTrait,
140 {
141 Self::install(db).await?;
142 Ok(Self::get_migration_with_status(db)
143 .await?
144 .into_iter()
145 .filter(|file| file.status == MigrationStatus::Pending)
146 .collect())
147 }
148
149 async fn get_applied_migrations<C>(db: &C) -> Result<Vec<Migration>, DbErr>
151 where
152 C: ConnectionTrait,
153 {
154 Self::install(db).await?;
155 Ok(Self::get_migration_with_status(db)
156 .await?
157 .into_iter()
158 .filter(|file| file.status == MigrationStatus::Applied)
159 .collect())
160 }
161
162 async fn install<C>(db: &C) -> Result<(), DbErr>
164 where
165 C: ConnectionTrait,
166 {
167 let builder = db.get_database_backend();
168 let table_name = Self::migration_table_name();
169 let schema = Schema::new(builder);
170 let mut stmt = schema
171 .create_table_from_entity(seaql_migrations::Entity)
172 .table_name(table_name);
173 stmt.if_not_exists();
174 db.execute(&stmt).await.map(|_| ())
175 }
176
177 async fn status<C>(db: &C) -> Result<(), DbErr>
179 where
180 C: ConnectionTrait,
181 {
182 Self::install(db).await?;
183
184 info!("Checking migration status");
185
186 for Migration { migration, status } in Self::get_migration_with_status(db).await? {
187 info!("Migration '{}'... {}", migration.name(), status);
188 }
189
190 Ok(())
191 }
192
193 async fn fresh<'c, C>(db: C) -> Result<(), DbErr>
195 where
196 C: IntoSchemaManagerConnection<'c>,
197 {
198 exec_with_connection::<'_, _, _>(db, move |manager| {
199 Box::pin(async move { exec_fresh::<Self>(manager).await })
200 })
201 .await
202 }
203
204 async fn refresh<'c, C>(db: C) -> Result<(), DbErr>
206 where
207 C: IntoSchemaManagerConnection<'c>,
208 {
209 exec_with_connection::<'_, _, _>(db, move |manager| {
210 Box::pin(async move {
211 exec_down::<Self>(manager, None).await?;
212 exec_up::<Self>(manager, None).await
213 })
214 })
215 .await
216 }
217
218 async fn reset<'c, C>(db: C) -> Result<(), DbErr>
220 where
221 C: IntoSchemaManagerConnection<'c>,
222 {
223 exec_with_connection::<'_, _, _>(db, move |manager| {
224 Box::pin(async move { exec_down::<Self>(manager, None).await })
225 })
226 .await
227 }
228
229 async fn up<'c, C>(db: C, steps: Option<u32>) -> Result<(), DbErr>
231 where
232 C: IntoSchemaManagerConnection<'c>,
233 {
234 exec_with_connection::<'_, _, _>(db, move |manager| {
235 Box::pin(async move { exec_up::<Self>(manager, steps).await })
236 })
237 .await
238 }
239
240 async fn down<'c, C>(db: C, steps: Option<u32>) -> Result<(), DbErr>
242 where
243 C: IntoSchemaManagerConnection<'c>,
244 {
245 exec_with_connection::<'_, _, _>(db, move |manager| {
246 Box::pin(async move { exec_down::<Self>(manager, steps).await })
247 })
248 .await
249 }
250}
251
252async fn exec_with_connection<'c, C, F>(db: C, f: F) -> Result<(), DbErr>
253where
254 C: IntoSchemaManagerConnection<'c>,
255 F: for<'b> Fn(
256 &'b SchemaManager<'_>,
257 ) -> Pin<Box<dyn Future<Output = Result<(), DbErr>> + Send + 'b>>,
258{
259 let db = db.into_schema_manager_connection();
260
261 match db.get_database_backend() {
262 DbBackend::Postgres => {
263 let transaction = db.begin().await?;
264 let manager = SchemaManager::new(&transaction);
265 f(&manager).await?;
266 transaction.commit().await
267 }
268 DbBackend::MySql | DbBackend::Sqlite => {
269 let manager = SchemaManager::new(db);
270 f(&manager).await
271 }
272 db => Err(DbErr::BackendNotSupported {
273 db: db.as_str(),
274 ctx: "exec_with_connection",
275 }),
276 }
277}
278
279async fn exec_fresh<M>(manager: &SchemaManager<'_>) -> Result<(), DbErr>
280where
281 M: MigratorTrait + ?Sized,
282{
283 let db = manager.get_connection();
284
285 M::install(db).await?;
286 let db_backend = db.get_database_backend();
287
288 if db_backend == DbBackend::Sqlite {
290 info!("Disabling foreign key check");
291 db.execute_raw(Statement::from_string(
292 db_backend,
293 "PRAGMA foreign_keys = OFF".to_owned(),
294 ))
295 .await?;
296 info!("Foreign key check disabled");
297 }
298
299 if db_backend == DbBackend::MySql {
301 info!("Dropping all foreign keys");
302 let stmt = query_mysql_foreign_keys(db);
303 let rows = db.query_all(&stmt).await?;
304 for row in rows.into_iter() {
305 let constraint_name: String = row.try_get("", "CONSTRAINT_NAME")?;
306 let table_name: String = row.try_get("", "TABLE_NAME")?;
307 info!(
308 "Dropping foreign key '{}' from table '{}'",
309 constraint_name, table_name
310 );
311 let mut stmt = ForeignKey::drop();
312 stmt.table(Alias::new(table_name.as_str()))
313 .name(constraint_name.as_str());
314 db.execute(&stmt).await?;
315 info!("Foreign key '{}' has been dropped", constraint_name);
316 }
317 info!("All foreign keys dropped");
318 }
319
320 let stmt = query_tables(db)?;
322 let rows = db.query_all(&stmt).await?;
323 for row in rows.into_iter() {
324 let table_name: String = row.try_get("", "table_name")?;
325 info!("Dropping table '{}'", table_name);
326 let mut stmt = Table::drop();
327 stmt.table(Alias::new(table_name.as_str()))
328 .if_exists()
329 .cascade();
330 db.execute(&stmt).await?;
331 info!("Table '{}' has been dropped", table_name);
332 }
333
334 if db_backend == DbBackend::Postgres {
336 info!("Dropping all types");
337 let stmt = query_pg_types(db);
338 let rows = db.query_all(&stmt).await?;
339 for row in rows {
340 let type_name: String = row.try_get("", "typname")?;
341 info!("Dropping type '{}'", type_name);
342 let mut stmt = Type::drop();
343 stmt.name(Alias::new(&type_name));
344 db.execute(&stmt).await?;
345 info!("Type '{}' has been dropped", type_name);
346 }
347 }
348
349 if db_backend == DbBackend::Sqlite {
351 info!("Restoring foreign key check");
352 db.execute_raw(Statement::from_string(
353 db_backend,
354 "PRAGMA foreign_keys = ON".to_owned(),
355 ))
356 .await?;
357 info!("Foreign key check restored");
358 }
359
360 exec_up::<M>(manager, None).await
362}
363
364async fn exec_up<M>(manager: &SchemaManager<'_>, mut steps: Option<u32>) -> Result<(), DbErr>
365where
366 M: MigratorTrait + ?Sized,
367{
368 let db = manager.get_connection();
369
370 M::install(db).await?;
371
372 if let Some(steps) = steps {
373 info!("Applying {} pending migrations", steps);
374 } else {
375 info!("Applying all pending migrations");
376 }
377
378 let migrations = M::get_pending_migrations(db).await?.into_iter();
379 if migrations.len() == 0 {
380 info!("No pending migrations");
381 }
382 for Migration { migration, .. } in migrations {
383 if let Some(steps) = steps.as_mut() {
384 if steps == &0 {
385 break;
386 }
387 *steps -= 1;
388 }
389 info!("Applying migration '{}'", migration.name());
390 migration.up(manager).await?;
391 info!("Migration '{}' has been applied", migration.name());
392 let now = SystemTime::now()
393 .duration_since(SystemTime::UNIX_EPOCH)
394 .expect("SystemTime before UNIX EPOCH!");
395 seaql_migrations::Entity::insert(seaql_migrations::ActiveModel {
396 version: ActiveValue::Set(migration.name().to_owned()),
397 applied_at: ActiveValue::Set(now.as_secs() as i64),
398 })
399 .table_name(M::migration_table_name())
400 .exec(db)
401 .await?;
402 }
403
404 Ok(())
405}
406
407async fn exec_down<M>(manager: &SchemaManager<'_>, mut steps: Option<u32>) -> Result<(), DbErr>
408where
409 M: MigratorTrait + ?Sized,
410{
411 let db = manager.get_connection();
412
413 M::install(db).await?;
414
415 if let Some(steps) = steps {
416 info!("Rolling back {} applied migrations", steps);
417 } else {
418 info!("Rolling back all applied migrations");
419 }
420
421 let migrations = M::get_applied_migrations(db).await?.into_iter().rev();
422 if migrations.len() == 0 {
423 info!("No applied migrations");
424 }
425 for Migration { migration, .. } in migrations {
426 if let Some(steps) = steps.as_mut() {
427 if steps == &0 {
428 break;
429 }
430 *steps -= 1;
431 }
432 info!("Rolling back migration '{}'", migration.name());
433 migration.down(manager).await?;
434 info!("Migration '{}' has been rollbacked", migration.name());
435 seaql_migrations::Entity::delete_many()
436 .filter(Expr::col(seaql_migrations::Column::Version).eq(migration.name()))
437 .table_name(M::migration_table_name())
438 .exec(db)
439 .await?;
440 }
441
442 Ok(())
443}
444
445fn query_tables<C>(db: &C) -> Result<SelectStatement, DbErr>
446where
447 C: ConnectionTrait,
448{
449 match db.get_database_backend() {
450 #[cfg(feature = "sqlx-mysql")]
451 DbBackend::MySql => Ok(sea_schema::mysql::MySql.query_tables()),
452 #[cfg(feature = "sqlx-postgres")]
453 DbBackend::Postgres => Ok(sea_schema::postgres::Postgres.query_tables()),
454 #[cfg(feature = "sqlx-sqlite")]
455 DbBackend::Sqlite => Ok(sea_schema::sqlite::Sqlite.query_tables()),
456 #[allow(unreachable_patterns)]
457 other => Err(DbErr::BackendNotSupported {
458 db: other.as_str(),
459 ctx: "query_tables",
460 }),
461 }
462}
463
464fn get_current_schema<C>(db: &C) -> SimpleExpr
466where
467 C: ConnectionTrait,
468{
469 match db.get_database_backend() {
470 #[cfg(feature = "sqlx-mysql")]
471 DbBackend::MySql => sea_schema::mysql::MySql::get_current_schema(),
472 #[cfg(feature = "sqlx-postgres")]
473 DbBackend::Postgres => sea_schema::postgres::Postgres::get_current_schema(),
474 #[cfg(feature = "sqlx-sqlite")]
475 DbBackend::Sqlite => sea_schema::sqlite::Sqlite::get_current_schema(),
476 #[allow(unreachable_patterns)]
477 other => panic!("{other:?} feature is off"),
478 }
479}
480
481#[derive(DeriveIden)]
482enum InformationSchema {
483 #[sea_orm(iden = "information_schema")]
484 Schema,
485 #[sea_orm(iden = "TABLE_NAME")]
486 TableName,
487 #[sea_orm(iden = "CONSTRAINT_NAME")]
488 ConstraintName,
489 TableConstraints,
490 TableSchema,
491 ConstraintType,
492}
493
494fn query_mysql_foreign_keys<C>(db: &C) -> SelectStatement
495where
496 C: ConnectionTrait,
497{
498 let mut stmt = Query::select();
499 stmt.columns([
500 InformationSchema::TableName,
501 InformationSchema::ConstraintName,
502 ])
503 .from((
504 InformationSchema::Schema,
505 InformationSchema::TableConstraints,
506 ))
507 .cond_where(
508 Condition::all()
509 .add(get_current_schema(db).equals((
510 InformationSchema::TableConstraints,
511 InformationSchema::TableSchema,
512 )))
513 .add(
514 Expr::col((
515 InformationSchema::TableConstraints,
516 InformationSchema::ConstraintType,
517 ))
518 .eq("FOREIGN KEY"),
519 ),
520 );
521 stmt
522}
523
524#[derive(DeriveIden)]
525enum PgType {
526 Table,
527 Typname,
528 Typnamespace,
529 Typelem,
530}
531
532#[derive(DeriveIden)]
533enum PgNamespace {
534 Table,
535 Oid,
536 Nspname,
537}
538
539fn query_pg_types<C>(db: &C) -> SelectStatement
540where
541 C: ConnectionTrait,
542{
543 let mut stmt = Query::select();
544 stmt.column(PgType::Typname)
545 .from(PgType::Table)
546 .join(
547 JoinType::LeftJoin,
548 PgNamespace::Table,
549 Expr::col((PgNamespace::Table, PgNamespace::Oid))
550 .equals((PgType::Table, PgType::Typnamespace)),
551 )
552 .cond_where(
553 Condition::all()
554 .add(get_current_schema(db).equals((PgNamespace::Table, PgNamespace::Nspname)))
555 .add(Expr::col((PgType::Table, PgType::Typelem)).eq(0)),
556 );
557 stmt
558}
559
560trait QueryTable {
561 type Statement;
562
563 fn table_name(self, table_name: DynIden) -> Self::Statement;
564}
565
566impl QueryTable for SelectStatement {
567 type Statement = SelectStatement;
568
569 fn table_name(mut self, table_name: DynIden) -> SelectStatement {
570 self.from(table_name);
571 self
572 }
573}
574
575impl QueryTable for sea_query::TableCreateStatement {
576 type Statement = sea_query::TableCreateStatement;
577
578 fn table_name(mut self, table_name: DynIden) -> sea_query::TableCreateStatement {
579 self.table(table_name);
580 self
581 }
582}
583
584impl<A> QueryTable for sea_orm::Insert<A>
585where
586 A: ActiveModelTrait,
587{
588 type Statement = sea_orm::Insert<A>;
589
590 fn table_name(mut self, table_name: DynIden) -> sea_orm::Insert<A> {
591 sea_orm::QueryTrait::query(&mut self).into_table(table_name);
592 self
593 }
594}
595
596impl<E> QueryTable for sea_orm::DeleteMany<E>
597where
598 E: EntityTrait,
599{
600 type Statement = sea_orm::DeleteMany<E>;
601
602 fn table_name(mut self, table_name: DynIden) -> sea_orm::DeleteMany<E> {
603 sea_orm::QueryTrait::query(&mut self).from_table(table_name);
604 self
605 }
606}