1use super::{Schema, TopologicalSort, entity::index_table_ref};
2use crate::{ConnectionTrait, DbBackend, DbErr, EntityTrait, Statement};
3use sea_query::{
4 ForeignKeyCreateStatement, Index, IndexCreateStatement, IntoIden, TableAlterStatement,
5 TableCreateStatement, TableName, TableRef, extension::postgres::TypeCreateStatement,
6};
7
8pub struct SchemaBuilder {
10 helper: Schema,
11 entities: Vec<EntitySchemaInfo>,
12}
13
14pub struct EntitySchemaInfo {
16 table: TableCreateStatement,
17 enums: Vec<TypeCreateStatement>,
18 indexes: Vec<IndexCreateStatement>,
19 schema_name: Option<String>,
22}
23
24impl std::fmt::Debug for SchemaBuilder {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 write!(f, "SchemaBuilder {{")?;
27 write!(f, " entities: [")?;
28 for (i, entity) in self.entities.iter().enumerate() {
29 if i > 0 {
30 write!(f, ", ")?;
31 }
32 entity.debug_print(f, &self.helper.backend)?;
33 }
34 write!(f, " ]")?;
35 write!(f, " }}")
36 }
37}
38
39impl std::fmt::Debug for EntitySchemaInfo {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 self.debug_print(f, &DbBackend::Sqlite)
42 }
43}
44
45impl SchemaBuilder {
46 pub fn new(schema: Schema) -> Self {
48 Self {
49 helper: schema,
50 entities: Default::default(),
51 }
52 }
53
54 pub fn register<E: EntityTrait>(mut self, entity: E) -> Self {
56 let entity = EntitySchemaInfo::new(entity, &self.helper);
57 if !self
58 .entities
59 .iter()
60 .any(|e| e.table.get_table_name() == entity.table.get_table_name())
61 {
62 self.entities.push(entity);
63 }
64 self
65 }
66
67 #[cfg(feature = "entity-registry")]
68 pub(crate) fn helper(&self) -> &Schema {
69 &self.helper
70 }
71
72 #[cfg(feature = "entity-registry")]
73 pub(crate) fn register_entity(&mut self, entity: EntitySchemaInfo) {
74 self.entities.push(entity);
75 }
76
77 #[cfg(feature = "schema-sync")]
80 #[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))]
81 pub fn sync<C>(self, db: &C) -> Result<(), DbErr>
82 where
83 C: ConnectionTrait + sea_schema::Connection,
84 {
85 let _existing = match db.get_database_backend() {
86 #[cfg(feature = "sqlx-mysql")]
87 DbBackend::MySql => {
88 use sea_schema::{mysql::discovery::SchemaDiscovery, probe::SchemaProbe};
89
90 let current_schema: String = db
91 .query_one(
92 sea_query::SelectStatement::new()
93 .expr(sea_schema::mysql::MySql::get_current_schema()),
94 )?
95 .ok_or_else(|| DbErr::RecordNotFound("Can't get current schema".into()))?
96 .try_get_by_index(0)?;
97
98 let mut target_schemas = std::collections::BTreeSet::new();
100 for entity in &self.entities {
101 let schema = entity.schema_name.as_deref().unwrap_or(¤t_schema);
102 target_schemas.insert(schema.to_string());
103 }
104
105 let mut tables_by_schema = std::collections::HashMap::new();
106 for schema_name in &target_schemas {
107 let schema_discovery = SchemaDiscovery::new_no_exec(schema_name);
108 let schema = schema_discovery
109 .discover_with(db)
110 .map_err(|err| DbErr::Query(crate::RuntimeErr::SqlxError(err.into())))?;
111
112 tables_by_schema.insert(
113 schema_name.clone(),
114 schema.tables.iter().map(|table| table.write()).collect(),
115 );
116 }
117
118 DiscoveredSchema {
119 current_schema,
120 tables_by_schema,
121 enums_by_schema: Default::default(),
122 }
123 }
124 #[cfg(feature = "sqlx-postgres")]
125 DbBackend::Postgres => {
126 use sea_schema::{postgres::discovery::SchemaDiscovery, probe::SchemaProbe};
127
128 let current_schema: String = db
129 .query_one(
130 sea_query::SelectStatement::new()
131 .expr(sea_schema::postgres::Postgres::get_current_schema()),
132 )?
133 .ok_or_else(|| DbErr::RecordNotFound("Can't get current schema".into()))?
134 .try_get_by_index(0)?;
135
136 let mut target_schemas = std::collections::BTreeSet::new();
138 for entity in &self.entities {
139 let schema = entity.schema_name.as_deref().unwrap_or(¤t_schema);
140 target_schemas.insert(schema.to_string());
141 }
142
143 let mut tables_by_schema = std::collections::HashMap::new();
144 let mut enums_by_schema = std::collections::HashMap::new();
145 for schema_name in &target_schemas {
146 let schema_discovery = SchemaDiscovery::new_no_exec(schema_name);
147 let schema = schema_discovery
148 .discover_with(db)
149 .map_err(|err| DbErr::Query(crate::RuntimeErr::SqlxError(err.into())))?;
150
151 tables_by_schema.insert(
152 schema_name.clone(),
153 schema.tables.iter().map(|table| table.write()).collect(),
154 );
155 enums_by_schema.insert(
156 schema_name.clone(),
157 schema.enums.iter().map(|def| def.write()).collect(),
158 );
159 }
160
161 DiscoveredSchema {
162 current_schema,
163 tables_by_schema,
164 enums_by_schema,
165 }
166 }
167 #[cfg(feature = "sqlx-sqlite")]
168 DbBackend::Sqlite => {
169 use sea_schema::sqlite::{SqliteDiscoveryError, discovery::SchemaDiscovery};
170 let schema = SchemaDiscovery::discover_with(db)
171 .map_err(|err| {
172 DbErr::Query(match err {
173 SqliteDiscoveryError::SqlxError(err) => {
174 crate::RuntimeErr::SqlxError(err.into())
175 }
176 _ => crate::RuntimeErr::Internal(format!("{err:?}")),
177 })
178 })?
179 .merge_indexes_into_table();
180 let mut tables_by_schema = std::collections::HashMap::new();
181 tables_by_schema.insert(
182 String::new(),
183 schema.tables.iter().map(|table| table.write()).collect(),
184 );
185 DiscoveredSchema {
186 current_schema: String::new(),
187 tables_by_schema,
188 enums_by_schema: Default::default(),
189 }
190 }
191 #[cfg(feature = "rusqlite")]
192 DbBackend::Sqlite => {
193 use sea_schema::sqlite::{SqliteDiscoveryError, discovery::SchemaDiscovery};
194 let schema = SchemaDiscovery::discover_with(db)
195 .map_err(|err| {
196 DbErr::Query(match err {
197 SqliteDiscoveryError::RusqliteError(err) => {
198 crate::RuntimeErr::Rusqlite(err.into())
199 }
200 _ => crate::RuntimeErr::Internal(format!("{err:?}")),
201 })
202 })?
203 .merge_indexes_into_table();
204 let mut tables_by_schema = std::collections::HashMap::new();
205 tables_by_schema.insert(
206 String::new(),
207 schema.tables.iter().map(|table| table.write()).collect(),
208 );
209 DiscoveredSchema {
210 current_schema: String::new(),
211 tables_by_schema,
212 enums_by_schema: Default::default(),
213 }
214 }
215 #[allow(unreachable_patterns)]
216 other => {
217 return Err(DbErr::BackendNotSupported {
218 db: other.as_str(),
219 ctx: "SchemaBuilder::sync",
220 });
221 }
222 };
223
224 #[allow(unreachable_code)]
225 let mut created_enums: Vec<Statement> = Default::default();
226
227 #[allow(unreachable_code)]
228 for table_name in self.sorted_tables() {
229 if let Some(entity) = self
230 .entities
231 .iter()
232 .find(|entity| table_name == get_table_name(entity.table.get_table_name()))
233 {
234 entity.sync(db, &_existing, &mut created_enums)?;
235 }
236 }
237
238 Ok(())
239 }
240
241 pub fn apply<C: ConnectionTrait>(self, db: &C) -> Result<(), DbErr> {
245 let mut created_enums: Vec<Statement> = Default::default();
246
247 for table_name in self.sorted_tables() {
248 if let Some(entity) = self
249 .entities
250 .iter()
251 .find(|entity| table_name == get_table_name(entity.table.get_table_name()))
252 {
253 entity.apply(db, &mut created_enums)?;
254 }
255 }
256
257 Ok(())
258 }
259
260 fn sorted_tables(&self) -> Vec<TableName> {
261 let mut sorter = TopologicalSort::<TableName>::new();
262
263 for entity in self.entities.iter() {
264 let table_name = get_table_name(entity.table.get_table_name());
265 sorter.insert(table_name);
266 }
267 for entity in self.entities.iter() {
268 let self_table = get_table_name(entity.table.get_table_name());
269 for fk in entity.table.get_foreign_key_create_stmts().iter() {
270 let fk = fk.get_foreign_key();
271 let ref_table = get_table_name(fk.get_ref_table());
272 if self_table != ref_table {
273 sorter.add_dependency(ref_table, self_table.clone());
275 }
276 }
277 }
278 let mut sorted = Vec::new();
279 while let Some(i) = sorter.pop() {
280 sorted.push(i);
281 }
282 if sorted.len() != self.entities.len() {
283 for entity in self.entities.iter() {
285 let table_name = get_table_name(entity.table.get_table_name());
286 if !sorted.contains(&table_name) {
287 sorted.push(table_name);
288 }
289 }
290 }
291
292 sorted
293 }
294}
295
296struct DiscoveredSchema {
297 current_schema: String,
299 tables_by_schema: std::collections::HashMap<String, Vec<TableCreateStatement>>,
301 enums_by_schema: std::collections::HashMap<String, Vec<TypeCreateStatement>>,
303}
304
305impl DiscoveredSchema {
306 fn find_table(
315 &self,
316 entity_schema: Option<&str>,
317 entity_table_name: &TableName,
318 ) -> Option<&TableCreateStatement> {
319 let schema = entity_schema.unwrap_or(&self.current_schema);
320 let schema_tables = self.tables_by_schema.get(schema)?;
321 let bare_entity_name = TableName(None, entity_table_name.1.clone());
324 schema_tables
325 .iter()
326 .find(|tbl| get_table_name(tbl.get_table_name()) == bare_entity_name)
327 }
328
329 fn find_enums(&self, entity_schema: Option<&str>) -> &[TypeCreateStatement] {
330 let schema = entity_schema.unwrap_or(&self.current_schema);
331 self.enums_by_schema
332 .get(schema)
333 .map(|v| v.as_slice())
334 .unwrap_or(&[])
335 }
336}
337
338impl EntitySchemaInfo {
339 pub fn new<E: EntityTrait>(entity: E, helper: &Schema) -> Self {
341 Self {
342 table: helper.create_table_from_entity(entity),
343 enums: helper.create_enum_from_entity(entity),
344 indexes: helper.create_index_from_entity(entity),
345 schema_name: entity.schema_name().map(|s| s.to_string()),
346 }
347 }
348
349 fn apply<C: ConnectionTrait>(
350 &self,
351 db: &C,
352 created_enums: &mut Vec<Statement>,
353 ) -> Result<(), DbErr> {
354 for stmt in self.enums.iter() {
355 let new_stmt = db.get_database_backend().build(stmt);
356 if !created_enums.iter().any(|s| s == &new_stmt) {
357 db.execute(stmt)?;
358 created_enums.push(new_stmt);
359 }
360 }
361 db.execute(&self.table)?;
362 for stmt in self.indexes.iter() {
363 db.execute(stmt)?;
364 }
365 Ok(())
366 }
367
368 #[allow(dead_code)]
370 fn sync<C: ConnectionTrait>(
371 &self,
372 db: &C,
373 existing: &DiscoveredSchema,
374 created_enums: &mut Vec<Statement>,
375 ) -> Result<(), DbErr> {
376 let db_backend = db.get_database_backend();
377
378 let existing_enums = existing.find_enums(self.schema_name.as_deref());
380 for stmt in self.enums.iter() {
381 let mut has_enum = false;
382 let new_stmt = db_backend.build(stmt);
383 for existing_enum in existing_enums {
384 if db_backend.build(existing_enum) == new_stmt {
385 has_enum = true;
386 break;
388 }
389 }
390 if !has_enum && !created_enums.iter().any(|s| s == &new_stmt) {
391 db.execute(stmt)?;
392 created_enums.push(new_stmt);
393 }
394 }
395 let table_name = get_table_name(self.table.get_table_name());
396 let existing_table = existing.find_table(self.schema_name.as_deref(), &table_name);
398 if let Some(existing_table) = existing_table {
399 for column_def in self.table.get_columns() {
400 let mut column_exists = false;
401 for existing_column in existing_table.get_columns() {
402 if column_def.get_column_name() == existing_column.get_column_name() {
403 column_exists = true;
404 break;
405 }
406 }
407 if !column_exists {
408 let mut renamed_from = "";
409 if let Some(comment) = &column_def.get_column_spec().comment
410 && let Some((_, suffix)) = comment.rsplit_once("renamed_from \"")
411 && let Some((prefix, _)) = suffix.split_once('"')
412 {
413 renamed_from = prefix;
414 }
415 if renamed_from.is_empty() {
416 db.execute(
417 TableAlterStatement::new()
418 .table(self.table.get_table_name().expect("Checked above").clone())
419 .add_column(column_def.to_owned()),
420 )?;
421 } else {
422 db.execute(
423 TableAlterStatement::new()
424 .table(self.table.get_table_name().expect("Checked above").clone())
425 .rename_column(
426 renamed_from.to_owned(),
427 column_def.get_column_name(),
428 ),
429 )?;
430 }
431 }
432 }
433 if db.get_database_backend() != DbBackend::Sqlite {
434 for foreign_key in self.table.get_foreign_key_create_stmts().iter() {
435 let mut key_exists = false;
436 for existing_key in existing_table.get_foreign_key_create_stmts().iter() {
437 if compare_foreign_key(foreign_key, existing_key) {
438 key_exists = true;
439 break;
440 }
441 }
442 if !key_exists {
443 db.execute(foreign_key)?;
444 }
445 }
446 }
447 } else {
448 db.execute(&self.table)?;
449 }
450 for stmt in self.indexes.iter() {
451 let mut has_index = false;
452 if let Some(existing_table) = existing_table {
453 for existing_index in existing_table.get_indexes() {
454 if existing_index.get_index_spec().get_column_names()
455 == stmt.get_index_spec().get_column_names()
456 {
457 has_index = true;
458 break;
459 }
460 }
461 }
462 if !has_index {
463 let mut stmt = stmt.clone();
465 stmt.if_not_exists();
466 db.execute(&stmt)?;
467 }
468 }
469 if let Some(existing_table) = existing_table {
470 for column_def in self.table.get_columns() {
473 if column_def.get_column_spec().unique {
474 let col_name = column_def.get_column_name();
475 let col_exists = existing_table
476 .get_columns()
477 .iter()
478 .any(|c| c.get_column_name() == col_name);
479 if !col_exists {
480 continue;
483 }
484 let already_unique = existing_table.get_indexes().iter().any(|idx| {
485 if !idx.is_unique_key() {
486 return false;
487 }
488 let cols = idx.get_index_spec().get_column_names();
489 cols.len() == 1 && cols[0] == col_name
490 });
491 if !already_unique {
492 let table_name =
493 self.table.get_table_name().expect("table must have a name");
494 let tbl_str = table_name.sea_orm_table().to_string();
495 let table_ref = index_table_ref(table_name.clone(), db_backend);
496 db.execute(
497 Index::create()
498 .name(format!("idx-{tbl_str}-{col_name}"))
499 .table(table_ref)
500 .col(col_name.into_iden())
501 .unique()
502 .if_not_exists(),
503 )?;
504 }
505 }
506 }
507 }
508 if let Some(existing_table) = existing_table {
509 for existing_index in existing_table.get_indexes() {
512 if existing_index.is_unique_key() {
513 let mut has_index = false;
514 for stmt in self.indexes.iter() {
515 if existing_index.get_index_spec().get_column_names()
516 == stmt.get_index_spec().get_column_names()
517 {
518 has_index = true;
519 break;
520 }
521 }
522 if !has_index {
527 let index_cols = existing_index.get_index_spec().get_column_names();
528 if index_cols.len() == 1 {
529 for column_def in self.table.get_columns() {
530 if column_def.get_column_name() == index_cols[0]
531 && column_def.get_column_spec().unique
532 {
533 has_index = true;
534 break;
535 }
536 }
537 }
538 }
539 if !has_index
540 && let Some(drop_existing) = existing_index
541 .get_index_spec()
542 .get_name()
543 .map(|s| s.to_owned())
544 {
545 if db_backend == DbBackend::Postgres {
546 db.execute(
551 TableAlterStatement::new()
552 .table(
553 self.table.get_table_name().expect("Checked above").clone(),
554 )
555 .drop_constraint(drop_existing),
556 )?;
557 } else {
558 db.execute(sea_query::Index::drop().name(drop_existing))?;
559 }
560 }
561 }
562 }
563 }
564 Ok(())
565 }
566
567 fn debug_print(
568 &self,
569 f: &mut std::fmt::Formatter<'_>,
570 backend: &DbBackend,
571 ) -> std::fmt::Result {
572 write!(f, "EntitySchemaInfo {{")?;
573 write!(f, " table: {:?}", backend.build(&self.table).to_string())?;
574 write!(f, " enums: [")?;
575 for (i, stmt) in self.enums.iter().enumerate() {
576 if i > 0 {
577 write!(f, ", ")?;
578 }
579 write!(f, "{:?}", backend.build(stmt).to_string())?;
580 }
581 write!(f, " ]")?;
582 write!(f, " indexes: [")?;
583 for (i, stmt) in self.indexes.iter().enumerate() {
584 if i > 0 {
585 write!(f, ", ")?;
586 }
587 write!(f, "{:?}", backend.build(stmt).to_string())?;
588 }
589 write!(f, " ]")?;
590 write!(f, " }}")
591 }
592}
593
594fn get_table_name(table_ref: Option<&TableRef>) -> TableName {
595 match table_ref {
596 Some(TableRef::Table(table_name, _)) => table_name.clone(),
597 None => panic!("Expect TableCreateStatement is properly built"),
598 _ => unreachable!("Unexpected {table_ref:?}"),
599 }
600}
601
602fn compare_foreign_key(a: &ForeignKeyCreateStatement, b: &ForeignKeyCreateStatement) -> bool {
603 let a = a.get_foreign_key();
604 let b = b.get_foreign_key();
605
606 a.get_name() == b.get_name()
607 || (a.get_ref_table() == b.get_ref_table()
608 && a.get_columns() == b.get_columns()
609 && a.get_ref_columns() == b.get_ref_columns())
610}