Skip to main content

sea_orm/schema/
builder.rs

1use super::{Schema, TopologicalSort, entity::index_table_ref};
2use crate::{ConnectionTrait, DbBackend, DbErr, EntityTrait, Statement};
3use sea_query::{
4    ForeignKeyCreateStatement, Index, IndexCreateStatement, IntoIden, TableAlterStatement,
5    TableCreateStatement, TableName, TableRef, extension::postgres::TypeCreateStatement,
6};
7
8/// A schema builder that can take a registry of Entities and synchronize it with database.
9pub struct SchemaBuilder {
10    helper: Schema,
11    entities: Vec<EntitySchemaInfo>,
12}
13
14/// Schema info for Entity. Can be used to re-create schema in database.
15pub struct EntitySchemaInfo {
16    table: TableCreateStatement,
17    enums: Vec<TypeCreateStatement>,
18    indexes: Vec<IndexCreateStatement>,
19    /// The schema name from the entity definition (e.g., `#[sea_orm(schema_name = "sys")]`).
20    /// `None` means the entity uses the database's current/default schema.
21    schema_name: Option<String>,
22}
23
24impl std::fmt::Debug for SchemaBuilder {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        write!(f, "SchemaBuilder {{")?;
27        write!(f, " entities: [")?;
28        for (i, entity) in self.entities.iter().enumerate() {
29            if i > 0 {
30                write!(f, ", ")?;
31            }
32            entity.debug_print(f, &self.helper.backend)?;
33        }
34        write!(f, " ]")?;
35        write!(f, " }}")
36    }
37}
38
39impl std::fmt::Debug for EntitySchemaInfo {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        self.debug_print(f, &DbBackend::Sqlite)
42    }
43}
44
45impl SchemaBuilder {
46    /// Creates a new schema builder
47    pub fn new(schema: Schema) -> Self {
48        Self {
49            helper: schema,
50            entities: Default::default(),
51        }
52    }
53
54    /// Register an entity to this schema
55    pub fn register<E: EntityTrait>(mut self, entity: E) -> Self {
56        let entity = EntitySchemaInfo::new(entity, &self.helper);
57        if !self
58            .entities
59            .iter()
60            .any(|e| e.table.get_table_name() == entity.table.get_table_name())
61        {
62            self.entities.push(entity);
63        }
64        self
65    }
66
67    #[cfg(feature = "entity-registry")]
68    pub(crate) fn helper(&self) -> &Schema {
69        &self.helper
70    }
71
72    #[cfg(feature = "entity-registry")]
73    pub(crate) fn register_entity(&mut self, entity: EntitySchemaInfo) {
74        self.entities.push(entity);
75    }
76
77    /// Synchronize the schema with database, will create missing tables, columns, unique keys, and foreign keys.
78    /// This operation is addition only, will not drop any table / columns.
79    #[cfg(feature = "schema-sync")]
80    #[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))]
81    pub fn sync<C>(self, db: &C) -> Result<(), DbErr>
82    where
83        C: ConnectionTrait + sea_schema::Connection,
84    {
85        let _existing = match db.get_database_backend() {
86            #[cfg(feature = "sqlx-mysql")]
87            DbBackend::MySql => {
88                use sea_schema::{mysql::discovery::SchemaDiscovery, probe::SchemaProbe};
89
90                let current_schema: String = db
91                    .query_one(
92                        sea_query::SelectStatement::new()
93                            .expr(sea_schema::mysql::MySql::get_current_schema()),
94                    )?
95                    .ok_or_else(|| DbErr::RecordNotFound("Can't get current schema".into()))?
96                    .try_get_by_index(0)?;
97
98                // Collect all unique schemas that registered entities belong to
99                let mut target_schemas = std::collections::BTreeSet::new();
100                for entity in &self.entities {
101                    let schema = entity.schema_name.as_deref().unwrap_or(&current_schema);
102                    target_schemas.insert(schema.to_string());
103                }
104
105                let mut tables_by_schema = std::collections::HashMap::new();
106                for schema_name in &target_schemas {
107                    let schema_discovery = SchemaDiscovery::new_no_exec(schema_name);
108                    let schema = schema_discovery
109                        .discover_with(db)
110                        .map_err(|err| DbErr::Query(crate::RuntimeErr::SqlxError(err.into())))?;
111
112                    tables_by_schema.insert(
113                        schema_name.clone(),
114                        schema.tables.iter().map(|table| table.write()).collect(),
115                    );
116                }
117
118                DiscoveredSchema {
119                    current_schema,
120                    tables_by_schema,
121                    enums_by_schema: Default::default(),
122                }
123            }
124            #[cfg(feature = "sqlx-postgres")]
125            DbBackend::Postgres => {
126                use sea_schema::{postgres::discovery::SchemaDiscovery, probe::SchemaProbe};
127
128                let current_schema: String = db
129                    .query_one(
130                        sea_query::SelectStatement::new()
131                            .expr(sea_schema::postgres::Postgres::get_current_schema()),
132                    )?
133                    .ok_or_else(|| DbErr::RecordNotFound("Can't get current schema".into()))?
134                    .try_get_by_index(0)?;
135
136                // Collect all unique schemas that registered entities belong to
137                let mut target_schemas = std::collections::BTreeSet::new();
138                for entity in &self.entities {
139                    let schema = entity.schema_name.as_deref().unwrap_or(&current_schema);
140                    target_schemas.insert(schema.to_string());
141                }
142
143                let mut tables_by_schema = std::collections::HashMap::new();
144                let mut enums_by_schema = std::collections::HashMap::new();
145                for schema_name in &target_schemas {
146                    let schema_discovery = SchemaDiscovery::new_no_exec(schema_name);
147                    let schema = schema_discovery
148                        .discover_with(db)
149                        .map_err(|err| DbErr::Query(crate::RuntimeErr::SqlxError(err.into())))?;
150
151                    tables_by_schema.insert(
152                        schema_name.clone(),
153                        schema.tables.iter().map(|table| table.write()).collect(),
154                    );
155                    enums_by_schema.insert(
156                        schema_name.clone(),
157                        schema.enums.iter().map(|def| def.write()).collect(),
158                    );
159                }
160
161                DiscoveredSchema {
162                    current_schema,
163                    tables_by_schema,
164                    enums_by_schema,
165                }
166            }
167            #[cfg(feature = "sqlx-sqlite")]
168            DbBackend::Sqlite => {
169                use sea_schema::sqlite::{SqliteDiscoveryError, discovery::SchemaDiscovery};
170                let schema = SchemaDiscovery::discover_with(db)
171                    .map_err(|err| {
172                        DbErr::Query(match err {
173                            SqliteDiscoveryError::SqlxError(err) => {
174                                crate::RuntimeErr::SqlxError(err.into())
175                            }
176                            _ => crate::RuntimeErr::Internal(format!("{err:?}")),
177                        })
178                    })?
179                    .merge_indexes_into_table();
180                let mut tables_by_schema = std::collections::HashMap::new();
181                tables_by_schema.insert(
182                    String::new(),
183                    schema.tables.iter().map(|table| table.write()).collect(),
184                );
185                DiscoveredSchema {
186                    current_schema: String::new(),
187                    tables_by_schema,
188                    enums_by_schema: Default::default(),
189                }
190            }
191            #[cfg(feature = "rusqlite")]
192            DbBackend::Sqlite => {
193                use sea_schema::sqlite::{SqliteDiscoveryError, discovery::SchemaDiscovery};
194                let schema = SchemaDiscovery::discover_with(db)
195                    .map_err(|err| {
196                        DbErr::Query(match err {
197                            SqliteDiscoveryError::RusqliteError(err) => {
198                                crate::RuntimeErr::Rusqlite(err.into())
199                            }
200                            _ => crate::RuntimeErr::Internal(format!("{err:?}")),
201                        })
202                    })?
203                    .merge_indexes_into_table();
204                let mut tables_by_schema = std::collections::HashMap::new();
205                tables_by_schema.insert(
206                    String::new(),
207                    schema.tables.iter().map(|table| table.write()).collect(),
208                );
209                DiscoveredSchema {
210                    current_schema: String::new(),
211                    tables_by_schema,
212                    enums_by_schema: Default::default(),
213                }
214            }
215            #[allow(unreachable_patterns)]
216            other => {
217                return Err(DbErr::BackendNotSupported {
218                    db: other.as_str(),
219                    ctx: "SchemaBuilder::sync",
220                });
221            }
222        };
223
224        #[allow(unreachable_code)]
225        let mut created_enums: Vec<Statement> = Default::default();
226
227        #[allow(unreachable_code)]
228        for table_name in self.sorted_tables() {
229            if let Some(entity) = self
230                .entities
231                .iter()
232                .find(|entity| table_name == get_table_name(entity.table.get_table_name()))
233            {
234                entity.sync(db, &_existing, &mut created_enums)?;
235            }
236        }
237
238        Ok(())
239    }
240
241    /// Create all registered tables, columns, unique keys, and foreign keys.
242    /// Fails if any table already exists. Use `sync` (feature `schema-sync`)
243    /// instead for an incremental version that diffs against the live schema.
244    pub fn apply<C: ConnectionTrait>(self, db: &C) -> Result<(), DbErr> {
245        let mut created_enums: Vec<Statement> = Default::default();
246
247        for table_name in self.sorted_tables() {
248            if let Some(entity) = self
249                .entities
250                .iter()
251                .find(|entity| table_name == get_table_name(entity.table.get_table_name()))
252            {
253                entity.apply(db, &mut created_enums)?;
254            }
255        }
256
257        Ok(())
258    }
259
260    fn sorted_tables(&self) -> Vec<TableName> {
261        let mut sorter = TopologicalSort::<TableName>::new();
262
263        for entity in self.entities.iter() {
264            let table_name = get_table_name(entity.table.get_table_name());
265            sorter.insert(table_name);
266        }
267        for entity in self.entities.iter() {
268            let self_table = get_table_name(entity.table.get_table_name());
269            for fk in entity.table.get_foreign_key_create_stmts().iter() {
270                let fk = fk.get_foreign_key();
271                let ref_table = get_table_name(fk.get_ref_table());
272                if self_table != ref_table {
273                    // self cycle is okay
274                    sorter.add_dependency(ref_table, self_table.clone());
275                }
276            }
277        }
278        let mut sorted = Vec::new();
279        while let Some(i) = sorter.pop() {
280            sorted.push(i);
281        }
282        if sorted.len() != self.entities.len() {
283            // push leftover tables
284            for entity in self.entities.iter() {
285                let table_name = get_table_name(entity.table.get_table_name());
286                if !sorted.contains(&table_name) {
287                    sorted.push(table_name);
288                }
289            }
290        }
291
292        sorted
293    }
294}
295
296struct DiscoveredSchema {
297    /// The current/default schema of the database connection (e.g., "public" for Postgres).
298    current_schema: String,
299    /// Tables discovered from the database, grouped by schema name.
300    tables_by_schema: std::collections::HashMap<String, Vec<TableCreateStatement>>,
301    /// Enums discovered from the database, grouped by schema name.
302    enums_by_schema: std::collections::HashMap<String, Vec<TypeCreateStatement>>,
303}
304
305impl DiscoveredSchema {
306    /// Find an existing table in the discovered schema that matches the given entity.
307    ///
308    /// `entity_schema` is the entity's explicit schema_name (from `#[sea_orm(schema_name = "...")]`).
309    /// If `None`, the entity uses the database's current/default schema.
310    ///
311    /// The comparison uses bare table names (without schema qualifiers) because
312    /// `sea-schema` discovery results do not include schema information in the
313    /// `TableCreateStatement`.
314    fn find_table(
315        &self,
316        entity_schema: Option<&str>,
317        entity_table_name: &TableName,
318    ) -> Option<&TableCreateStatement> {
319        let schema = entity_schema.unwrap_or(&self.current_schema);
320        let schema_tables = self.tables_by_schema.get(schema)?;
321        // Strip schema from entity table name for comparison, because discovered
322        // tables from sea-schema do not carry schema qualifiers.
323        let bare_entity_name = TableName(None, entity_table_name.1.clone());
324        schema_tables
325            .iter()
326            .find(|tbl| get_table_name(tbl.get_table_name()) == bare_entity_name)
327    }
328
329    fn find_enums(&self, entity_schema: Option<&str>) -> &[TypeCreateStatement] {
330        let schema = entity_schema.unwrap_or(&self.current_schema);
331        self.enums_by_schema
332            .get(schema)
333            .map(|v| v.as_slice())
334            .unwrap_or(&[])
335    }
336}
337
338impl EntitySchemaInfo {
339    /// Creates a EntitySchemaInfo object given a generic Entity.
340    pub fn new<E: EntityTrait>(entity: E, helper: &Schema) -> Self {
341        Self {
342            table: helper.create_table_from_entity(entity),
343            enums: helper.create_enum_from_entity(entity),
344            indexes: helper.create_index_from_entity(entity),
345            schema_name: entity.schema_name().map(|s| s.to_string()),
346        }
347    }
348
349    fn apply<C: ConnectionTrait>(
350        &self,
351        db: &C,
352        created_enums: &mut Vec<Statement>,
353    ) -> Result<(), DbErr> {
354        for stmt in self.enums.iter() {
355            let new_stmt = db.get_database_backend().build(stmt);
356            if !created_enums.iter().any(|s| s == &new_stmt) {
357                db.execute(stmt)?;
358                created_enums.push(new_stmt);
359            }
360        }
361        db.execute(&self.table)?;
362        for stmt in self.indexes.iter() {
363            db.execute(stmt)?;
364        }
365        Ok(())
366    }
367
368    // better to always compile this function
369    #[allow(dead_code)]
370    fn sync<C: ConnectionTrait>(
371        &self,
372        db: &C,
373        existing: &DiscoveredSchema,
374        created_enums: &mut Vec<Statement>,
375    ) -> Result<(), DbErr> {
376        let db_backend = db.get_database_backend();
377
378        // create enum before creating table
379        let existing_enums = existing.find_enums(self.schema_name.as_deref());
380        for stmt in self.enums.iter() {
381            let mut has_enum = false;
382            let new_stmt = db_backend.build(stmt);
383            for existing_enum in existing_enums {
384                if db_backend.build(existing_enum) == new_stmt {
385                    has_enum = true;
386                    // TODO add enum variants
387                    break;
388                }
389            }
390            if !has_enum && !created_enums.iter().any(|s| s == &new_stmt) {
391                db.execute(stmt)?;
392                created_enums.push(new_stmt);
393            }
394        }
395        let table_name = get_table_name(self.table.get_table_name());
396        // Use schema-aware lookup: find existing table in the correct schema
397        let existing_table = existing.find_table(self.schema_name.as_deref(), &table_name);
398        if let Some(existing_table) = existing_table {
399            for column_def in self.table.get_columns() {
400                let mut column_exists = false;
401                for existing_column in existing_table.get_columns() {
402                    if column_def.get_column_name() == existing_column.get_column_name() {
403                        column_exists = true;
404                        break;
405                    }
406                }
407                if !column_exists {
408                    let mut renamed_from = "";
409                    if let Some(comment) = &column_def.get_column_spec().comment
410                        && let Some((_, suffix)) = comment.rsplit_once("renamed_from \"")
411                        && let Some((prefix, _)) = suffix.split_once('"')
412                    {
413                        renamed_from = prefix;
414                    }
415                    if renamed_from.is_empty() {
416                        db.execute(
417                            TableAlterStatement::new()
418                                .table(self.table.get_table_name().expect("Checked above").clone())
419                                .add_column(column_def.to_owned()),
420                        )?;
421                    } else {
422                        db.execute(
423                            TableAlterStatement::new()
424                                .table(self.table.get_table_name().expect("Checked above").clone())
425                                .rename_column(
426                                    renamed_from.to_owned(),
427                                    column_def.get_column_name(),
428                                ),
429                        )?;
430                    }
431                }
432            }
433            if db.get_database_backend() != DbBackend::Sqlite {
434                for foreign_key in self.table.get_foreign_key_create_stmts().iter() {
435                    let mut key_exists = false;
436                    for existing_key in existing_table.get_foreign_key_create_stmts().iter() {
437                        if compare_foreign_key(foreign_key, existing_key) {
438                            key_exists = true;
439                            break;
440                        }
441                    }
442                    if !key_exists {
443                        db.execute(foreign_key)?;
444                    }
445                }
446            }
447        } else {
448            db.execute(&self.table)?;
449        }
450        for stmt in self.indexes.iter() {
451            let mut has_index = false;
452            if let Some(existing_table) = existing_table {
453                for existing_index in existing_table.get_indexes() {
454                    if existing_index.get_index_spec().get_column_names()
455                        == stmt.get_index_spec().get_column_names()
456                    {
457                        has_index = true;
458                        break;
459                    }
460                }
461            }
462            if !has_index {
463                // shall we do alter table add constraint for unique index?
464                let mut stmt = stmt.clone();
465                stmt.if_not_exists();
466                db.execute(&stmt)?;
467            }
468        }
469        if let Some(existing_table) = existing_table {
470            // For columns with a column-level UNIQUE constraint (#[sea_orm(unique)]) that
471            // already exist in the table but do not yet have a unique index, create one.
472            for column_def in self.table.get_columns() {
473                if column_def.get_column_spec().unique {
474                    let col_name = column_def.get_column_name();
475                    let col_exists = existing_table
476                        .get_columns()
477                        .iter()
478                        .any(|c| c.get_column_name() == col_name);
479                    if !col_exists {
480                        // Column is being added in this sync pass; the ALTER TABLE ADD COLUMN
481                        // will include the UNIQUE inline, so no separate index needed.
482                        continue;
483                    }
484                    let already_unique = existing_table.get_indexes().iter().any(|idx| {
485                        if !idx.is_unique_key() {
486                            return false;
487                        }
488                        let cols = idx.get_index_spec().get_column_names();
489                        cols.len() == 1 && cols[0] == col_name
490                    });
491                    if !already_unique {
492                        let table_name =
493                            self.table.get_table_name().expect("table must have a name");
494                        let tbl_str = table_name.sea_orm_table().to_string();
495                        let table_ref = index_table_ref(table_name.clone(), db_backend);
496                        db.execute(
497                            Index::create()
498                                .name(format!("idx-{tbl_str}-{col_name}"))
499                                .table(table_ref)
500                                .col(col_name.into_iden())
501                                .unique()
502                                .if_not_exists(),
503                        )?;
504                    }
505                }
506            }
507        }
508        if let Some(existing_table) = existing_table {
509            // find all unique keys from existing table
510            // if it no longer exist in new schema, drop it
511            for existing_index in existing_table.get_indexes() {
512                if existing_index.is_unique_key() {
513                    let mut has_index = false;
514                    for stmt in self.indexes.iter() {
515                        if existing_index.get_index_spec().get_column_names()
516                            == stmt.get_index_spec().get_column_names()
517                        {
518                            has_index = true;
519                            break;
520                        }
521                    }
522                    // Also check if the unique index corresponds to a column-level UNIQUE
523                    // constraint (from #[sea_orm(unique)]). These are embedded in the CREATE
524                    // TABLE column definition and not tracked in self.indexes, so we must not
525                    // try to drop them during sync.
526                    if !has_index {
527                        let index_cols = existing_index.get_index_spec().get_column_names();
528                        if index_cols.len() == 1 {
529                            for column_def in self.table.get_columns() {
530                                if column_def.get_column_name() == index_cols[0]
531                                    && column_def.get_column_spec().unique
532                                {
533                                    has_index = true;
534                                    break;
535                                }
536                            }
537                        }
538                    }
539                    if !has_index
540                        && let Some(drop_existing) = existing_index
541                            .get_index_spec()
542                            .get_name()
543                            .map(|s| s.to_owned())
544                    {
545                        if db_backend == DbBackend::Postgres {
546                            // On PostgreSQL, unique indexes created via column-level UNIQUE
547                            // (e.g. ADD COLUMN ... UNIQUE) are backed by a named constraint.
548                            // DROP INDEX fails on constraint-owned indexes; use
549                            // ALTER TABLE ... DROP CONSTRAINT instead.
550                            db.execute(
551                                TableAlterStatement::new()
552                                    .table(
553                                        self.table.get_table_name().expect("Checked above").clone(),
554                                    )
555                                    .drop_constraint(drop_existing),
556                            )?;
557                        } else {
558                            db.execute(sea_query::Index::drop().name(drop_existing))?;
559                        }
560                    }
561                }
562            }
563        }
564        Ok(())
565    }
566
567    fn debug_print(
568        &self,
569        f: &mut std::fmt::Formatter<'_>,
570        backend: &DbBackend,
571    ) -> std::fmt::Result {
572        write!(f, "EntitySchemaInfo {{")?;
573        write!(f, " table: {:?}", backend.build(&self.table).to_string())?;
574        write!(f, " enums: [")?;
575        for (i, stmt) in self.enums.iter().enumerate() {
576            if i > 0 {
577                write!(f, ", ")?;
578            }
579            write!(f, "{:?}", backend.build(stmt).to_string())?;
580        }
581        write!(f, " ]")?;
582        write!(f, " indexes: [")?;
583        for (i, stmt) in self.indexes.iter().enumerate() {
584            if i > 0 {
585                write!(f, ", ")?;
586            }
587            write!(f, "{:?}", backend.build(stmt).to_string())?;
588        }
589        write!(f, " ]")?;
590        write!(f, " }}")
591    }
592}
593
594fn get_table_name(table_ref: Option<&TableRef>) -> TableName {
595    match table_ref {
596        Some(TableRef::Table(table_name, _)) => table_name.clone(),
597        None => panic!("Expect TableCreateStatement is properly built"),
598        _ => unreachable!("Unexpected {table_ref:?}"),
599    }
600}
601
602fn compare_foreign_key(a: &ForeignKeyCreateStatement, b: &ForeignKeyCreateStatement) -> bool {
603    let a = a.get_foreign_key();
604    let b = b.get_foreign_key();
605
606    a.get_name() == b.get_name()
607        || (a.get_ref_table() == b.get_ref_table()
608            && a.get_columns() == b.get_columns()
609            && a.get_ref_columns() == b.get_ref_columns())
610}