1use std::collections::HashSet;
2use std::fmt::Display;
3use std::future::Future;
4use std::pin::Pin;
5use std::time::SystemTime;
6use tracing::info;
7
8use sea_orm::sea_query::{
9 self, extension::postgres::Type, Alias, Expr, ForeignKey, IntoIden, JoinType, Order, Query,
10 SelectStatement, SimpleExpr, Table,
11};
12use sea_orm::{
13 ActiveModelTrait, ActiveValue, Condition, ConnectionTrait, DbBackend, DbErr, DeriveIden,
14 DynIden, EntityTrait, FromQueryResult, Iterable, QueryFilter, Schema, Statement,
15 TransactionTrait,
16};
17use sea_schema::{mysql::MySql, postgres::Postgres, probe::SchemaProbe, sqlite::Sqlite};
18
19use super::{seaql_migrations, IntoSchemaManagerConnection, MigrationTrait, SchemaManager};
20
/// Whether a migration is recorded as applied in the migration table.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MigrationStatus {
    /// The migration's name is not recorded in the migration table.
    Pending,
    /// The migration's name is recorded in the migration table.
    Applied,
}
29
30impl Display for MigrationStatus {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 let status = match self {
33 MigrationStatus::Pending => "Pending",
34 MigrationStatus::Applied => "Applied",
35 };
36 write!(f, "{status}")
37 }
38}
39
40pub struct Migration {
41 migration: Box<dyn MigrationTrait>,
42 status: MigrationStatus,
43}
44
45impl Migration {
46 pub fn name(&self) -> &str {
48 self.migration.name()
49 }
50
51 pub fn status(&self) -> MigrationStatus {
53 self.status
54 }
55}
56
57#[async_trait::async_trait]
59pub trait MigratorTrait: Send {
60 fn migrations() -> Vec<Box<dyn MigrationTrait>>;
62
63 fn migration_table_name() -> DynIden {
65 seaql_migrations::Entity.into_iden()
66 }
67
68 fn get_migration_files() -> Vec<Migration> {
70 Self::migrations()
71 .into_iter()
72 .map(|migration| Migration {
73 migration,
74 status: MigrationStatus::Pending,
75 })
76 .collect()
77 }
78
79 async fn get_migration_models<C>(db: &C) -> Result<Vec<seaql_migrations::Model>, DbErr>
81 where
82 C: ConnectionTrait,
83 {
84 Self::install(db).await?;
85 let stmt = Query::select()
86 .table_name(Self::migration_table_name())
87 .columns(seaql_migrations::Column::iter().map(IntoIden::into_iden))
88 .order_by(seaql_migrations::Column::Version, Order::Asc)
89 .to_owned();
90 let builder = db.get_database_backend();
91 seaql_migrations::Model::find_by_statement(builder.build(&stmt))
92 .all(db)
93 .await
94 }
95
96 async fn get_migration_with_status<C>(db: &C) -> Result<Vec<Migration>, DbErr>
98 where
99 C: ConnectionTrait,
100 {
101 Self::install(db).await?;
102 let mut migration_files = Self::get_migration_files();
103 let migration_models = Self::get_migration_models(db).await?;
104
105 let migration_in_db: HashSet<String> = migration_models
106 .into_iter()
107 .map(|model| model.version)
108 .collect();
109 let migration_in_fs: HashSet<String> = migration_files
110 .iter()
111 .map(|file| file.migration.name().to_string())
112 .collect();
113
114 let pending_migrations = &migration_in_fs - &migration_in_db;
115 for migration_file in migration_files.iter_mut() {
116 if !pending_migrations.contains(migration_file.migration.name()) {
117 migration_file.status = MigrationStatus::Applied;
118 }
119 }
120
121 let missing_migrations_in_fs = &migration_in_db - &migration_in_fs;
122 let errors: Vec<String> = missing_migrations_in_fs
123 .iter()
124 .map(|missing_migration| {
125 format!("Migration file of version '{missing_migration}' is missing, this migration has been applied but its file is missing")
126 }).collect();
127
128 if !errors.is_empty() {
129 Err(DbErr::Custom(errors.join("\n")))
130 } else {
131 Ok(migration_files)
132 }
133 }
134
135 async fn get_pending_migrations<C>(db: &C) -> Result<Vec<Migration>, DbErr>
137 where
138 C: ConnectionTrait,
139 {
140 Self::install(db).await?;
141 Ok(Self::get_migration_with_status(db)
142 .await?
143 .into_iter()
144 .filter(|file| file.status == MigrationStatus::Pending)
145 .collect())
146 }
147
148 async fn get_applied_migrations<C>(db: &C) -> Result<Vec<Migration>, DbErr>
150 where
151 C: ConnectionTrait,
152 {
153 Self::install(db).await?;
154 Ok(Self::get_migration_with_status(db)
155 .await?
156 .into_iter()
157 .filter(|file| file.status == MigrationStatus::Applied)
158 .collect())
159 }
160
161 async fn install<C>(db: &C) -> Result<(), DbErr>
163 where
164 C: ConnectionTrait,
165 {
166 let builder = db.get_database_backend();
167 let table_name = Self::migration_table_name();
168 let schema = Schema::new(builder);
169 let mut stmt = schema
170 .create_table_from_entity(seaql_migrations::Entity)
171 .table_name(table_name);
172 stmt.if_not_exists();
173 db.execute(builder.build(&stmt)).await.map(|_| ())
174 }
175
176 async fn status<C>(db: &C) -> Result<(), DbErr>
178 where
179 C: ConnectionTrait,
180 {
181 Self::install(db).await?;
182
183 info!("Checking migration status");
184
185 for Migration { migration, status } in Self::get_migration_with_status(db).await? {
186 info!("Migration '{}'... {}", migration.name(), status);
187 }
188
189 Ok(())
190 }
191
192 async fn fresh<'c, C>(db: C) -> Result<(), DbErr>
194 where
195 C: IntoSchemaManagerConnection<'c>,
196 {
197 exec_with_connection::<'_, _, _>(db, move |manager| {
198 Box::pin(async move { exec_fresh::<Self>(manager).await })
199 })
200 .await
201 }
202
203 async fn refresh<'c, C>(db: C) -> Result<(), DbErr>
205 where
206 C: IntoSchemaManagerConnection<'c>,
207 {
208 exec_with_connection::<'_, _, _>(db, move |manager| {
209 Box::pin(async move {
210 exec_down::<Self>(manager, None).await?;
211 exec_up::<Self>(manager, None).await
212 })
213 })
214 .await
215 }
216
217 async fn reset<'c, C>(db: C) -> Result<(), DbErr>
219 where
220 C: IntoSchemaManagerConnection<'c>,
221 {
222 exec_with_connection::<'_, _, _>(db, move |manager| {
223 Box::pin(async move { exec_down::<Self>(manager, None).await })
224 })
225 .await
226 }
227
228 async fn up<'c, C>(db: C, steps: Option<u32>) -> Result<(), DbErr>
230 where
231 C: IntoSchemaManagerConnection<'c>,
232 {
233 exec_with_connection::<'_, _, _>(db, move |manager| {
234 Box::pin(async move { exec_up::<Self>(manager, steps).await })
235 })
236 .await
237 }
238
239 async fn down<'c, C>(db: C, steps: Option<u32>) -> Result<(), DbErr>
241 where
242 C: IntoSchemaManagerConnection<'c>,
243 {
244 exec_with_connection::<'_, _, _>(db, move |manager| {
245 Box::pin(async move { exec_down::<Self>(manager, steps).await })
246 })
247 .await
248 }
249}
250
251async fn exec_with_connection<'c, C, F>(db: C, f: F) -> Result<(), DbErr>
252where
253 C: IntoSchemaManagerConnection<'c>,
254 F: for<'b> Fn(
255 &'b SchemaManager<'_>,
256 ) -> Pin<Box<dyn Future<Output = Result<(), DbErr>> + Send + 'b>>,
257{
258 let db = db.into_schema_manager_connection();
259
260 match db.get_database_backend() {
261 DbBackend::Postgres => {
262 let transaction = db.begin().await?;
263 let manager = SchemaManager::new(&transaction);
264 f(&manager).await?;
265 transaction.commit().await
266 }
267 DbBackend::MySql | DbBackend::Sqlite => {
268 let manager = SchemaManager::new(db);
269 f(&manager).await
270 }
271 }
272}
273
274async fn exec_fresh<M>(manager: &SchemaManager<'_>) -> Result<(), DbErr>
275where
276 M: MigratorTrait + ?Sized,
277{
278 let db = manager.get_connection();
279
280 M::install(db).await?;
281 let db_backend = db.get_database_backend();
282
283 if db_backend == DbBackend::Sqlite {
285 info!("Disabling foreign key check");
286 db.execute(Statement::from_string(
287 db_backend,
288 "PRAGMA foreign_keys = OFF".to_owned(),
289 ))
290 .await?;
291 info!("Foreign key check disabled");
292 }
293
294 if db_backend == DbBackend::MySql {
296 info!("Dropping all foreign keys");
297 let stmt = query_mysql_foreign_keys(db);
298 let rows = db.query_all(db_backend.build(&stmt)).await?;
299 for row in rows.into_iter() {
300 let constraint_name: String = row.try_get("", "CONSTRAINT_NAME")?;
301 let table_name: String = row.try_get("", "TABLE_NAME")?;
302 info!(
303 "Dropping foreign key '{}' from table '{}'",
304 constraint_name, table_name
305 );
306 let mut stmt = ForeignKey::drop();
307 stmt.table(Alias::new(table_name.as_str()))
308 .name(constraint_name.as_str());
309 db.execute(db_backend.build(&stmt)).await?;
310 info!("Foreign key '{}' has been dropped", constraint_name);
311 }
312 info!("All foreign keys dropped");
313 }
314
315 let stmt = query_tables(db).await;
317 let rows = db.query_all(db_backend.build(&stmt)).await?;
318 for row in rows.into_iter() {
319 let table_name: String = row.try_get("", "table_name")?;
320 info!("Dropping table '{}'", table_name);
321 let mut stmt = Table::drop();
322 stmt.table(Alias::new(table_name.as_str()))
323 .if_exists()
324 .cascade();
325 db.execute(db_backend.build(&stmt)).await?;
326 info!("Table '{}' has been dropped", table_name);
327 }
328
329 if db_backend == DbBackend::Postgres {
331 info!("Dropping all types");
332 let stmt = query_pg_types(db);
333 let rows = db.query_all(db_backend.build(&stmt)).await?;
334 for row in rows {
335 let type_name: String = row.try_get("", "typname")?;
336 info!("Dropping type '{}'", type_name);
337 let mut stmt = Type::drop();
338 stmt.name(Alias::new(&type_name));
339 db.execute(db_backend.build(&stmt)).await?;
340 info!("Type '{}' has been dropped", type_name);
341 }
342 }
343
344 if db_backend == DbBackend::Sqlite {
346 info!("Restoring foreign key check");
347 db.execute(Statement::from_string(
348 db_backend,
349 "PRAGMA foreign_keys = ON".to_owned(),
350 ))
351 .await?;
352 info!("Foreign key check restored");
353 }
354
355 exec_up::<M>(manager, None).await
357}
358
359async fn exec_up<M>(manager: &SchemaManager<'_>, mut steps: Option<u32>) -> Result<(), DbErr>
360where
361 M: MigratorTrait + ?Sized,
362{
363 let db = manager.get_connection();
364
365 M::install(db).await?;
366
367 if let Some(steps) = steps {
368 info!("Applying {} pending migrations", steps);
369 } else {
370 info!("Applying all pending migrations");
371 }
372
373 let migrations = M::get_pending_migrations(db).await?.into_iter();
374 if migrations.len() == 0 {
375 info!("No pending migrations");
376 }
377 for Migration { migration, .. } in migrations {
378 if let Some(steps) = steps.as_mut() {
379 if steps == &0 {
380 break;
381 }
382 *steps -= 1;
383 }
384 info!("Applying migration '{}'", migration.name());
385 migration.up(manager).await?;
386 info!("Migration '{}' has been applied", migration.name());
387 let now = SystemTime::now()
388 .duration_since(SystemTime::UNIX_EPOCH)
389 .expect("SystemTime before UNIX EPOCH!");
390 seaql_migrations::Entity::insert(seaql_migrations::ActiveModel {
391 version: ActiveValue::Set(migration.name().to_owned()),
392 applied_at: ActiveValue::Set(now.as_secs() as i64),
393 })
394 .table_name(M::migration_table_name())
395 .exec(db)
396 .await?;
397 }
398
399 Ok(())
400}
401
402async fn exec_down<M>(manager: &SchemaManager<'_>, mut steps: Option<u32>) -> Result<(), DbErr>
403where
404 M: MigratorTrait + ?Sized,
405{
406 let db = manager.get_connection();
407
408 M::install(db).await?;
409
410 if let Some(steps) = steps {
411 info!("Rolling back {} applied migrations", steps);
412 } else {
413 info!("Rolling back all applied migrations");
414 }
415
416 let migrations = M::get_applied_migrations(db).await?.into_iter().rev();
417 if migrations.len() == 0 {
418 info!("No applied migrations");
419 }
420 for Migration { migration, .. } in migrations {
421 if let Some(steps) = steps.as_mut() {
422 if steps == &0 {
423 break;
424 }
425 *steps -= 1;
426 }
427 info!("Rolling back migration '{}'", migration.name());
428 migration.down(manager).await?;
429 info!("Migration '{}' has been rollbacked", migration.name());
430 seaql_migrations::Entity::delete_many()
431 .filter(Expr::col(seaql_migrations::Column::Version).eq(migration.name()))
432 .table_name(M::migration_table_name())
433 .exec(db)
434 .await?;
435 }
436
437 Ok(())
438}
439
440async fn query_tables<C>(db: &C) -> SelectStatement
441where
442 C: ConnectionTrait,
443{
444 match db.get_database_backend() {
445 DbBackend::MySql => MySql.query_tables(),
446 DbBackend::Postgres => Postgres.query_tables(),
447 DbBackend::Sqlite => Sqlite.query_tables(),
448 }
449}
450
451fn get_current_schema<C>(db: &C) -> SimpleExpr
452where
453 C: ConnectionTrait,
454{
455 match db.get_database_backend() {
456 DbBackend::MySql => MySql::get_current_schema(),
457 DbBackend::Postgres => Postgres::get_current_schema(),
458 DbBackend::Sqlite => unimplemented!(),
459 }
460}
461
/// Identifiers for querying MySQL's `information_schema.table_constraints` catalog, used by
/// `query_mysql_foreign_keys`.
///
/// NOTE(review): variants without an explicit `iden` are presumably rendered by `DeriveIden`
/// in snake_case (e.g. `TableConstraints` -> `table_constraints`); confirm against the
/// sea-orm version in use. Renaming a variant changes the generated SQL.
#[derive(DeriveIden)]
enum InformationSchema {
    // The `information_schema` schema itself.
    #[sea_orm(iden = "information_schema")]
    Schema,
    // Spelled upper-case; `exec_fresh` reads this column with `try_get("", "TABLE_NAME")`.
    #[sea_orm(iden = "TABLE_NAME")]
    TableName,
    // Spelled upper-case; `exec_fresh` reads this column with `try_get("", "CONSTRAINT_NAME")`.
    #[sea_orm(iden = "CONSTRAINT_NAME")]
    ConstraintName,
    TableConstraints,
    TableSchema,
    ConstraintType,
}
474
475fn query_mysql_foreign_keys<C>(db: &C) -> SelectStatement
476where
477 C: ConnectionTrait,
478{
479 let mut stmt = Query::select();
480 stmt.columns([
481 InformationSchema::TableName,
482 InformationSchema::ConstraintName,
483 ])
484 .from((
485 InformationSchema::Schema,
486 InformationSchema::TableConstraints,
487 ))
488 .cond_where(
489 Condition::all()
490 .add(Expr::expr(get_current_schema(db)).equals((
491 InformationSchema::TableConstraints,
492 InformationSchema::TableSchema,
493 )))
494 .add(
495 Expr::col((
496 InformationSchema::TableConstraints,
497 InformationSchema::ConstraintType,
498 ))
499 .eq("FOREIGN KEY"),
500 ),
501 );
502 stmt
503}
504
505#[derive(DeriveIden)]
506enum PgType {
507 Table,
508 Typname,
509 Typnamespace,
510 Typelem,
511}
512
513#[derive(DeriveIden)]
514enum PgNamespace {
515 Table,
516 Oid,
517 Nspname,
518}
519
520fn query_pg_types<C>(db: &C) -> SelectStatement
521where
522 C: ConnectionTrait,
523{
524 let mut stmt = Query::select();
525 stmt.column(PgType::Typname)
526 .from(PgType::Table)
527 .join(
528 JoinType::LeftJoin,
529 PgNamespace::Table,
530 Expr::col((PgNamespace::Table, PgNamespace::Oid))
531 .equals((PgType::Table, PgType::Typnamespace)),
532 )
533 .cond_where(
534 Condition::all()
535 .add(
536 Expr::expr(get_current_schema(db))
537 .equals((PgNamespace::Table, PgNamespace::Nspname)),
538 )
539 .add(Expr::col((PgType::Table, PgType::Typelem)).eq(0)),
540 );
541 stmt
542}
543
/// Lets a statement be redirected to a caller-chosen table. The migrator uses this to target
/// the table named by `MigratorTrait::migration_table_name`.
trait QueryTable {
    /// Statement type returned once the table has been set.
    type Statement;

    /// Consumes the statement, points it at `table_name`, and returns it.
    fn table_name(self, table_name: DynIden) -> Self::Statement;
}
549
550impl QueryTable for SelectStatement {
551 type Statement = SelectStatement;
552
553 fn table_name(mut self, table_name: DynIden) -> SelectStatement {
554 self.from(table_name);
555 self
556 }
557}
558
559impl QueryTable for sea_query::TableCreateStatement {
560 type Statement = sea_query::TableCreateStatement;
561
562 fn table_name(mut self, table_name: DynIden) -> sea_query::TableCreateStatement {
563 self.table(table_name);
564 self
565 }
566}
567
568impl<A> QueryTable for sea_orm::Insert<A>
569where
570 A: ActiveModelTrait,
571{
572 type Statement = sea_orm::Insert<A>;
573
574 fn table_name(mut self, table_name: DynIden) -> sea_orm::Insert<A> {
575 sea_orm::QueryTrait::query(&mut self).into_table(table_name);
576 self
577 }
578}
579
580impl<E> QueryTable for sea_orm::DeleteMany<E>
581where
582 E: EntityTrait,
583{
584 type Statement = sea_orm::DeleteMany<E>;
585
586 fn table_name(mut self, table_name: DynIden) -> sea_orm::DeleteMany<E> {
587 sea_orm::QueryTrait::query(&mut self).from_table(table_name);
588 self
589 }
590}