// sea_orm/schema/builder.rs

1use super::{Schema, TopologicalSort};
2use crate::{ConnectionTrait, DbBackend, DbErr, EntityTrait, Statement};
3use sea_query::{
4    ForeignKeyCreateStatement, Index, IndexCreateStatement, IntoIden, TableAlterStatement,
5    TableCreateStatement, TableName, TableRef, extension::postgres::TypeCreateStatement,
6};
7
8/// A schema builder that can take a registry of Entities and synchronize it with database.
9pub struct SchemaBuilder {
10    helper: Schema,
11    entities: Vec<EntitySchemaInfo>,
12}
13
14/// Schema info for Entity. Can be used to re-create schema in database.
15pub struct EntitySchemaInfo {
16    table: TableCreateStatement,
17    enums: Vec<TypeCreateStatement>,
18    indexes: Vec<IndexCreateStatement>,
19}
20
21impl std::fmt::Debug for SchemaBuilder {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        write!(f, "SchemaBuilder {{")?;
24        write!(f, " entities: [")?;
25        for (i, entity) in self.entities.iter().enumerate() {
26            if i > 0 {
27                write!(f, ", ")?;
28            }
29            entity.debug_print(f, &self.helper.backend)?;
30        }
31        write!(f, " ]")?;
32        write!(f, " }}")
33    }
34}
35
36impl std::fmt::Debug for EntitySchemaInfo {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        self.debug_print(f, &DbBackend::Sqlite)
39    }
40}
41
42impl SchemaBuilder {
43    /// Creates a new schema builder
44    pub fn new(schema: Schema) -> Self {
45        Self {
46            helper: schema,
47            entities: Default::default(),
48        }
49    }
50
51    /// Register an entity to this schema
52    pub fn register<E: EntityTrait>(mut self, entity: E) -> Self {
53        let entity = EntitySchemaInfo::new(entity, &self.helper);
54        if !self
55            .entities
56            .iter()
57            .any(|e| e.table.get_table_name() == entity.table.get_table_name())
58        {
59            self.entities.push(entity);
60        }
61        self
62    }
63
64    #[cfg(feature = "entity-registry")]
65    pub(crate) fn helper(&self) -> &Schema {
66        &self.helper
67    }
68
69    #[cfg(feature = "entity-registry")]
70    pub(crate) fn register_entity(&mut self, entity: EntitySchemaInfo) {
71        self.entities.push(entity);
72    }
73
74    /// Synchronize the schema with database, will create missing tables, columns, unique keys, and foreign keys.
75    /// This operation is addition only, will not drop any table / columns.
76    #[cfg(feature = "schema-sync")]
77    #[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))]
78    pub fn sync<C>(self, db: &C) -> Result<(), DbErr>
79    where
80        C: ConnectionTrait + sea_schema::Connection,
81    {
82        let _existing = match db.get_database_backend() {
83            #[cfg(feature = "sqlx-mysql")]
84            DbBackend::MySql => {
85                use sea_schema::{mysql::discovery::SchemaDiscovery, probe::SchemaProbe};
86
87                let current_schema: String = db
88                    .query_one(
89                        sea_query::SelectStatement::new()
90                            .expr(sea_schema::mysql::MySql::get_current_schema()),
91                    )?
92                    .ok_or_else(|| DbErr::RecordNotFound("Can't get current schema".into()))?
93                    .try_get_by_index(0)?;
94                let schema_discovery = SchemaDiscovery::new_no_exec(&current_schema);
95
96                let schema = schema_discovery
97                    .discover_with(db)
98                    .map_err(|err| DbErr::Query(crate::RuntimeErr::SqlxError(err.into())))?;
99
100                DiscoveredSchema {
101                    tables: schema.tables.iter().map(|table| table.write()).collect(),
102                    enums: vec![],
103                }
104            }
105            #[cfg(feature = "sqlx-postgres")]
106            DbBackend::Postgres => {
107                use sea_schema::{postgres::discovery::SchemaDiscovery, probe::SchemaProbe};
108
109                let current_schema: String = db
110                    .query_one(
111                        sea_query::SelectStatement::new()
112                            .expr(sea_schema::postgres::Postgres::get_current_schema()),
113                    )?
114                    .ok_or_else(|| DbErr::RecordNotFound("Can't get current schema".into()))?
115                    .try_get_by_index(0)?;
116                let schema_discovery = SchemaDiscovery::new_no_exec(&current_schema);
117
118                let schema = schema_discovery
119                    .discover_with(db)
120                    .map_err(|err| DbErr::Query(crate::RuntimeErr::SqlxError(err.into())))?;
121
122                DiscoveredSchema {
123                    tables: schema.tables.iter().map(|table| table.write()).collect(),
124                    enums: schema.enums.iter().map(|def| def.write()).collect(),
125                }
126            }
127            #[cfg(feature = "sqlx-sqlite")]
128            DbBackend::Sqlite => {
129                use sea_schema::sqlite::{SqliteDiscoveryError, discovery::SchemaDiscovery};
130                let schema = SchemaDiscovery::discover_with(db)
131                    .map_err(|err| {
132                        DbErr::Query(match err {
133                            SqliteDiscoveryError::SqlxError(err) => {
134                                crate::RuntimeErr::SqlxError(err.into())
135                            }
136                            _ => crate::RuntimeErr::Internal(format!("{err:?}")),
137                        })
138                    })?
139                    .merge_indexes_into_table();
140                DiscoveredSchema {
141                    tables: schema.tables.iter().map(|table| table.write()).collect(),
142                    enums: vec![],
143                }
144            }
145            #[cfg(feature = "rusqlite")]
146            DbBackend::Sqlite => {
147                use sea_schema::sqlite::{SqliteDiscoveryError, discovery::SchemaDiscovery};
148                let schema = SchemaDiscovery::discover_with(db)
149                    .map_err(|err| {
150                        DbErr::Query(match err {
151                            SqliteDiscoveryError::RusqliteError(err) => {
152                                crate::RuntimeErr::Rusqlite(err.into())
153                            }
154                            _ => crate::RuntimeErr::Internal(format!("{err:?}")),
155                        })
156                    })?
157                    .merge_indexes_into_table();
158                DiscoveredSchema {
159                    tables: schema.tables.iter().map(|table| table.write()).collect(),
160                    enums: vec![],
161                }
162            }
163            #[allow(unreachable_patterns)]
164            other => {
165                return Err(DbErr::BackendNotSupported {
166                    db: other.as_str(),
167                    ctx: "SchemaBuilder::sync",
168                });
169            }
170        };
171
172        #[allow(unreachable_code)]
173        let mut created_enums: Vec<Statement> = Default::default();
174
175        #[allow(unreachable_code)]
176        for table_name in self.sorted_tables() {
177            if let Some(entity) = self
178                .entities
179                .iter()
180                .find(|entity| table_name == get_table_name(entity.table.get_table_name()))
181            {
182                entity.sync(db, &_existing, &mut created_enums)?;
183            }
184        }
185
186        Ok(())
187    }
188
189    /// Apply this schema to a database, will create all registered tables, columns, unique keys, and foreign keys.
190    /// Will fail if any table already exists. Use [`sync`] if you want an incremental version that can perform schema diff.
191    pub fn apply<C: ConnectionTrait>(self, db: &C) -> Result<(), DbErr> {
192        let mut created_enums: Vec<Statement> = Default::default();
193
194        for table_name in self.sorted_tables() {
195            if let Some(entity) = self
196                .entities
197                .iter()
198                .find(|entity| table_name == get_table_name(entity.table.get_table_name()))
199            {
200                entity.apply(db, &mut created_enums)?;
201            }
202        }
203
204        Ok(())
205    }
206
207    fn sorted_tables(&self) -> Vec<TableName> {
208        let mut sorter = TopologicalSort::<TableName>::new();
209
210        for entity in self.entities.iter() {
211            let table_name = get_table_name(entity.table.get_table_name());
212            sorter.insert(table_name);
213        }
214        for entity in self.entities.iter() {
215            let self_table = get_table_name(entity.table.get_table_name());
216            for fk in entity.table.get_foreign_key_create_stmts().iter() {
217                let fk = fk.get_foreign_key();
218                let ref_table = get_table_name(fk.get_ref_table());
219                if self_table != ref_table {
220                    // self cycle is okay
221                    sorter.add_dependency(ref_table, self_table.clone());
222                }
223            }
224        }
225        let mut sorted = Vec::new();
226        while let Some(i) = sorter.pop() {
227            sorted.push(i);
228        }
229        if sorted.len() != self.entities.len() {
230            // push leftover tables
231            for entity in self.entities.iter() {
232                let table_name = get_table_name(entity.table.get_table_name());
233                if !sorted.contains(&table_name) {
234                    sorted.push(table_name);
235                }
236            }
237        }
238
239        sorted
240    }
241}
242
243struct DiscoveredSchema {
244    tables: Vec<TableCreateStatement>,
245    enums: Vec<TypeCreateStatement>,
246}
247
248impl EntitySchemaInfo {
249    /// Creates a EntitySchemaInfo object given a generic Entity.
250    pub fn new<E: EntityTrait>(entity: E, helper: &Schema) -> Self {
251        Self {
252            table: helper.create_table_from_entity(entity),
253            enums: helper.create_enum_from_entity(entity),
254            indexes: helper.create_index_from_entity(entity),
255        }
256    }
257
258    fn apply<C: ConnectionTrait>(
259        &self,
260        db: &C,
261        created_enums: &mut Vec<Statement>,
262    ) -> Result<(), DbErr> {
263        for stmt in self.enums.iter() {
264            let new_stmt = db.get_database_backend().build(stmt);
265            if !created_enums.iter().any(|s| s == &new_stmt) {
266                db.execute(stmt)?;
267                created_enums.push(new_stmt);
268            }
269        }
270        db.execute(&self.table)?;
271        for stmt in self.indexes.iter() {
272            db.execute(stmt)?;
273        }
274        Ok(())
275    }
276
277    // better to always compile this function
278    #[allow(dead_code)]
279    fn sync<C: ConnectionTrait>(
280        &self,
281        db: &C,
282        existing: &DiscoveredSchema,
283        created_enums: &mut Vec<Statement>,
284    ) -> Result<(), DbErr> {
285        let db_backend = db.get_database_backend();
286
287        // create enum before creating table
288        for stmt in self.enums.iter() {
289            let mut has_enum = false;
290            let new_stmt = db_backend.build(stmt);
291            for existing_enum in &existing.enums {
292                if db_backend.build(existing_enum) == new_stmt {
293                    has_enum = true;
294                    // TODO add enum variants
295                    break;
296                }
297            }
298            if !has_enum && !created_enums.iter().any(|s| s == &new_stmt) {
299                db.execute(stmt)?;
300                created_enums.push(new_stmt);
301            }
302        }
303        let table_name = get_table_name(self.table.get_table_name());
304        let mut existing_table = None;
305        for tbl in &existing.tables {
306            if get_table_name(tbl.get_table_name()) == table_name {
307                existing_table = Some(tbl);
308                break;
309            }
310        }
311        if let Some(existing_table) = existing_table {
312            for column_def in self.table.get_columns() {
313                let mut column_exists = false;
314                for existing_column in existing_table.get_columns() {
315                    if column_def.get_column_name() == existing_column.get_column_name() {
316                        column_exists = true;
317                        break;
318                    }
319                }
320                if !column_exists {
321                    let mut renamed_from = "";
322                    if let Some(comment) = &column_def.get_column_spec().comment {
323                        if let Some((_, suffix)) = comment.rsplit_once("renamed_from \"") {
324                            if let Some((prefix, _)) = suffix.split_once('"') {
325                                renamed_from = prefix;
326                            }
327                        }
328                    }
329                    if renamed_from.is_empty() {
330                        db.execute(
331                            TableAlterStatement::new()
332                                .table(self.table.get_table_name().expect("Checked above").clone())
333                                .add_column(column_def.to_owned()),
334                        )?;
335                    } else {
336                        db.execute(
337                            TableAlterStatement::new()
338                                .table(self.table.get_table_name().expect("Checked above").clone())
339                                .rename_column(
340                                    renamed_from.to_owned(),
341                                    column_def.get_column_name(),
342                                ),
343                        )?;
344                    }
345                }
346            }
347            if db.get_database_backend() != DbBackend::Sqlite {
348                for foreign_key in self.table.get_foreign_key_create_stmts().iter() {
349                    let mut key_exists = false;
350                    for existing_key in existing_table.get_foreign_key_create_stmts().iter() {
351                        if compare_foreign_key(foreign_key, existing_key) {
352                            key_exists = true;
353                            break;
354                        }
355                    }
356                    if !key_exists {
357                        db.execute(foreign_key)?;
358                    }
359                }
360            }
361        } else {
362            db.execute(&self.table)?;
363        }
364        for stmt in self.indexes.iter() {
365            let mut has_index = false;
366            if let Some(existing_table) = existing_table {
367                for existing_index in existing_table.get_indexes() {
368                    if existing_index.get_index_spec().get_column_names()
369                        == stmt.get_index_spec().get_column_names()
370                    {
371                        has_index = true;
372                        break;
373                    }
374                }
375            }
376            if !has_index {
377                // shall we do alter table add constraint for unique index?
378                let mut stmt = stmt.clone();
379                stmt.if_not_exists();
380                db.execute(&stmt)?;
381            }
382        }
383        if let Some(existing_table) = existing_table {
384            // For columns with a column-level UNIQUE constraint (#[sea_orm(unique)]) that
385            // already exist in the table but do not yet have a unique index, create one.
386            for column_def in self.table.get_columns() {
387                if column_def.get_column_spec().unique {
388                    let col_name = column_def.get_column_name();
389                    let col_exists = existing_table
390                        .get_columns()
391                        .iter()
392                        .any(|c| c.get_column_name() == col_name);
393                    if !col_exists {
394                        // Column is being added in this sync pass; the ALTER TABLE ADD COLUMN
395                        // will include the UNIQUE inline, so no separate index needed.
396                        continue;
397                    }
398                    let already_unique = existing_table.get_indexes().iter().any(|idx| {
399                        if !idx.is_unique_key() {
400                            return false;
401                        }
402                        let cols = idx.get_index_spec().get_column_names();
403                        cols.len() == 1 && cols[0] == col_name
404                    });
405                    if !already_unique {
406                        let table_name =
407                            self.table.get_table_name().expect("table must have a name");
408                        let tbl_str = table_name.sea_orm_table().to_string();
409                        let table_ref = table_name.clone();
410                        db.execute(
411                            Index::create()
412                                .name(format!("idx-{tbl_str}-{col_name}"))
413                                .table(table_ref)
414                                .col(col_name.into_iden())
415                                .unique()
416                                .if_not_exists(),
417                        )?;
418                    }
419                }
420            }
421        }
422        if let Some(existing_table) = existing_table {
423            // find all unique keys from existing table
424            // if it no longer exist in new schema, drop it
425            for existing_index in existing_table.get_indexes() {
426                if existing_index.is_unique_key() {
427                    let mut has_index = false;
428                    for stmt in self.indexes.iter() {
429                        if existing_index.get_index_spec().get_column_names()
430                            == stmt.get_index_spec().get_column_names()
431                        {
432                            has_index = true;
433                            break;
434                        }
435                    }
436                    // Also check if the unique index corresponds to a column-level UNIQUE
437                    // constraint (from #[sea_orm(unique)]). These are embedded in the CREATE
438                    // TABLE column definition and not tracked in self.indexes, so we must not
439                    // try to drop them during sync.
440                    if !has_index {
441                        let index_cols = existing_index.get_index_spec().get_column_names();
442                        if index_cols.len() == 1 {
443                            for column_def in self.table.get_columns() {
444                                if column_def.get_column_name() == index_cols[0]
445                                    && column_def.get_column_spec().unique
446                                {
447                                    has_index = true;
448                                    break;
449                                }
450                            }
451                        }
452                    }
453                    if !has_index {
454                        if let Some(drop_existing) = existing_index.get_index_spec().get_name() {
455                            db.execute(sea_query::Index::drop().name(drop_existing))?;
456                        }
457                    }
458                }
459            }
460        }
461        Ok(())
462    }
463
464    fn debug_print(
465        &self,
466        f: &mut std::fmt::Formatter<'_>,
467        backend: &DbBackend,
468    ) -> std::fmt::Result {
469        write!(f, "EntitySchemaInfo {{")?;
470        write!(f, " table: {:?}", backend.build(&self.table).to_string())?;
471        write!(f, " enums: [")?;
472        for (i, stmt) in self.enums.iter().enumerate() {
473            if i > 0 {
474                write!(f, ", ")?;
475            }
476            write!(f, "{:?}", backend.build(stmt).to_string())?;
477        }
478        write!(f, " ]")?;
479        write!(f, " indexes: [")?;
480        for (i, stmt) in self.indexes.iter().enumerate() {
481            if i > 0 {
482                write!(f, ", ")?;
483            }
484            write!(f, "{:?}", backend.build(stmt).to_string())?;
485        }
486        write!(f, " ]")?;
487        write!(f, " }}")
488    }
489}
490
491fn get_table_name(table_ref: Option<&TableRef>) -> TableName {
492    match table_ref {
493        Some(TableRef::Table(table_name, _)) => table_name.clone(),
494        None => panic!("Expect TableCreateStatement is properly built"),
495        _ => unreachable!("Unexpected {table_ref:?}"),
496    }
497}
498
499fn compare_foreign_key(a: &ForeignKeyCreateStatement, b: &ForeignKeyCreateStatement) -> bool {
500    let a = a.get_foreign_key();
501    let b = b.get_foreign_key();
502
503    a.get_name() == b.get_name()
504        || (a.get_ref_table() == b.get_ref_table()
505            && a.get_columns() == b.get_columns()
506            && a.get_ref_columns() == b.get_ref_columns())
507}