sea_orm/schema/
builder.rs

1use super::{Schema, TopologicalSort};
2use crate::{ConnectionTrait, DbBackend, DbErr, EntityTrait, Statement};
3use sea_query::{
4    ForeignKeyCreateStatement, IndexCreateStatement, TableAlterStatement, TableCreateStatement,
5    TableName, TableRef, extension::postgres::TypeCreateStatement,
6};
7
/// A schema builder that can take a registry of Entities and synchronize it with database.
///
/// Entities are added with [`SchemaBuilder::register`]; `apply` (and `sync`, behind the
/// `schema-sync` feature) then process the registered tables in foreign-key dependency order.
pub struct SchemaBuilder {
    /// Statement helper: builds each registered entity's statements and supplies the backend
    /// used when rendering `Debug` output.
    helper: Schema,
    /// Registered entities, in registration order.
    entities: Vec<EntitySchemaInfo>,
}
13
/// Schema info for Entity. Can be used to re-create schema in database.
///
/// Holds the statements generated for a single entity by a [`Schema`] helper
/// (see [`EntitySchemaInfo::new`]).
pub struct EntitySchemaInfo {
    /// Table creation statement; the entity's foreign keys are carried inside it.
    table: TableCreateStatement,
    /// Type-creation statements (Postgres extension) for the entity's enums.
    enums: Vec<TypeCreateStatement>,
    /// Index creation statements, executed after the table has been created.
    indexes: Vec<IndexCreateStatement>,
}
20
21impl std::fmt::Debug for SchemaBuilder {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        write!(f, "SchemaBuilder {{")?;
24        write!(f, " entities: [")?;
25        for (i, entity) in self.entities.iter().enumerate() {
26            if i > 0 {
27                write!(f, ", ")?;
28            }
29            entity.debug_print(f, &self.helper.backend)?;
30        }
31        write!(f, " ]")?;
32        write!(f, " }}")
33    }
34}
35
36impl std::fmt::Debug for EntitySchemaInfo {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        self.debug_print(f, &DbBackend::Sqlite)
39    }
40}
41
42impl SchemaBuilder {
43    /// Creates a new schema builder
44    pub fn new(schema: Schema) -> Self {
45        Self {
46            helper: schema,
47            entities: Default::default(),
48        }
49    }
50
51    /// Register an entity to this schema
52    pub fn register<E: EntityTrait>(mut self, entity: E) -> Self {
53        let entity = EntitySchemaInfo::new(entity, &self.helper);
54        if !self
55            .entities
56            .iter()
57            .any(|e| e.table.get_table_name() == entity.table.get_table_name())
58        {
59            self.entities.push(entity);
60        }
61        self
62    }
63
64    #[cfg(feature = "entity-registry")]
65    pub(crate) fn helper(&self) -> &Schema {
66        &self.helper
67    }
68
69    #[cfg(feature = "entity-registry")]
70    pub(crate) fn register_entity(&mut self, entity: EntitySchemaInfo) {
71        self.entities.push(entity);
72    }
73
74    /// Synchronize the schema with database, will create missing tables, columns, unique keys, and foreign keys.
75    /// This operation is addition only, will not drop any table / columns.
76    #[cfg(feature = "schema-sync")]
77    #[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))]
78    pub fn sync<C>(self, db: &C) -> Result<(), DbErr>
79    where
80        C: ConnectionTrait + sea_schema::Connection,
81    {
82        let _existing = match db.get_database_backend() {
83            #[cfg(feature = "sqlx-mysql")]
84            DbBackend::MySql => {
85                use sea_schema::{mysql::discovery::SchemaDiscovery, probe::SchemaProbe};
86
87                let current_schema: String = db
88                    .query_one(
89                        sea_query::SelectStatement::new()
90                            .expr(sea_schema::mysql::MySql::get_current_schema()),
91                    )?
92                    .ok_or_else(|| DbErr::RecordNotFound("Can't get current schema".into()))?
93                    .try_get_by_index(0)?;
94                let schema_discovery = SchemaDiscovery::new_no_exec(&current_schema);
95
96                let schema = schema_discovery
97                    .discover_with(db)
98                    .map_err(|err| DbErr::Query(crate::RuntimeErr::Internal(format!("{err:?}"))))?;
99
100                DiscoveredSchema {
101                    tables: schema.tables.iter().map(|table| table.write()).collect(),
102                    enums: vec![],
103                }
104            }
105            #[cfg(feature = "sqlx-postgres")]
106            DbBackend::Postgres => {
107                use sea_schema::{postgres::discovery::SchemaDiscovery, probe::SchemaProbe};
108
109                let current_schema: String = db
110                    .query_one(
111                        sea_query::SelectStatement::new()
112                            .expr(sea_schema::postgres::Postgres::get_current_schema()),
113                    )?
114                    .ok_or_else(|| DbErr::RecordNotFound("Can't get current schema".into()))?
115                    .try_get_by_index(0)?;
116                let schema_discovery = SchemaDiscovery::new_no_exec(&current_schema);
117
118                let schema = schema_discovery
119                    .discover_with(db)
120                    .map_err(|err| DbErr::Query(crate::RuntimeErr::Internal(format!("{err:?}"))))?;
121
122                DiscoveredSchema {
123                    tables: schema.tables.iter().map(|table| table.write()).collect(),
124                    enums: schema.enums.iter().map(|def| def.write()).collect(),
125                }
126            }
127            #[cfg(feature = "sqlx-sqlite")]
128            DbBackend::Sqlite => {
129                use sea_schema::sqlite::{SqliteDiscoveryError, discovery::SchemaDiscovery};
130                let schema = SchemaDiscovery::discover_with(db)
131                    .map_err(|err| {
132                        DbErr::Query(match err {
133                            SqliteDiscoveryError::SqlxError(err) => {
134                                crate::RuntimeErr::SqlxError(err.into())
135                            }
136                            _ => crate::RuntimeErr::Internal(format!("{err:?}")),
137                        })
138                    })?
139                    .merge_indexes_into_table();
140                DiscoveredSchema {
141                    tables: schema.tables.iter().map(|table| table.write()).collect(),
142                    enums: vec![],
143                }
144            }
145            #[cfg(feature = "rusqlite")]
146            DbBackend::Sqlite => {
147                use sea_schema::sqlite::{SqliteDiscoveryError, discovery::SchemaDiscovery};
148                let schema = SchemaDiscovery::discover_with(db)
149                    .map_err(|err| {
150                        DbErr::Query(match err {
151                            SqliteDiscoveryError::RusqliteError(err) => {
152                                crate::RuntimeErr::Rusqlite(err.into())
153                            }
154                            _ => crate::RuntimeErr::Internal(format!("{err:?}")),
155                        })
156                    })?
157                    .merge_indexes_into_table();
158                DiscoveredSchema {
159                    tables: schema.tables.iter().map(|table| table.write()).collect(),
160                    enums: vec![],
161                }
162            }
163            #[allow(unreachable_patterns)]
164            other => {
165                return Err(DbErr::BackendNotSupported {
166                    db: other.as_str(),
167                    ctx: "SchemaBuilder::sync",
168                });
169            }
170        };
171
172        #[allow(unreachable_code)]
173        let mut created_enums: Vec<Statement> = Default::default();
174
175        #[allow(unreachable_code)]
176        for table_name in self.sorted_tables() {
177            if let Some(entity) = self
178                .entities
179                .iter()
180                .find(|entity| table_name == get_table_name(entity.table.get_table_name()))
181            {
182                entity.sync(db, &_existing, &mut created_enums)?;
183            }
184        }
185
186        Ok(())
187    }
188
189    /// Apply this schema to a database, will create all registered tables, columns, unique keys, and foreign keys.
190    /// Will fail if any table already exists. Use [`sync`] if you want an incremental version that can perform schema diff.
191    pub fn apply<C: ConnectionTrait>(self, db: &C) -> Result<(), DbErr> {
192        let mut created_enums: Vec<Statement> = Default::default();
193
194        for table_name in self.sorted_tables() {
195            if let Some(entity) = self
196                .entities
197                .iter()
198                .find(|entity| table_name == get_table_name(entity.table.get_table_name()))
199            {
200                entity.apply(db, &mut created_enums)?;
201            }
202        }
203
204        Ok(())
205    }
206
207    fn sorted_tables(&self) -> Vec<TableName> {
208        let mut sorter = TopologicalSort::<TableName>::new();
209
210        for entity in self.entities.iter() {
211            let table_name = get_table_name(entity.table.get_table_name());
212            sorter.insert(table_name);
213        }
214        for entity in self.entities.iter() {
215            let self_table = get_table_name(entity.table.get_table_name());
216            for fk in entity.table.get_foreign_key_create_stmts().iter() {
217                let fk = fk.get_foreign_key();
218                let ref_table = get_table_name(fk.get_ref_table());
219                if self_table != ref_table {
220                    // self cycle is okay
221                    sorter.add_dependency(ref_table, self_table.clone());
222                }
223            }
224        }
225        let mut sorted = Vec::new();
226        while let Some(i) = sorter.pop() {
227            sorted.push(i);
228        }
229        if sorted.len() != self.entities.len() {
230            // push leftover tables
231            for entity in self.entities.iter() {
232                let table_name = get_table_name(entity.table.get_table_name());
233                if !sorted.contains(&table_name) {
234                    sorted.push(table_name);
235                }
236            }
237        }
238
239        sorted
240    }
241}
242
/// Schema read back from the live database by `SchemaBuilder::sync`, re-expressed as create
/// statements so it can be compared with the registered entities.
struct DiscoveredSchema {
    /// One create statement per discovered table.
    tables: Vec<TableCreateStatement>,
    /// Discovered enum types; only populated on the Postgres path.
    enums: Vec<TypeCreateStatement>,
}
247
248impl EntitySchemaInfo {
249    /// Creates a EntitySchemaInfo object given a generic Entity.
250    pub fn new<E: EntityTrait>(entity: E, helper: &Schema) -> Self {
251        Self {
252            table: helper.create_table_from_entity(entity),
253            enums: helper.create_enum_from_entity(entity),
254            indexes: helper.create_index_from_entity(entity),
255        }
256    }
257
258    fn apply<C: ConnectionTrait>(
259        &self,
260        db: &C,
261        created_enums: &mut Vec<Statement>,
262    ) -> Result<(), DbErr> {
263        for stmt in self.enums.iter() {
264            let new_stmt = db.get_database_backend().build(stmt);
265            if !created_enums.iter().any(|s| s == &new_stmt) {
266                db.execute(stmt)?;
267                created_enums.push(new_stmt);
268            }
269        }
270        db.execute(&self.table)?;
271        for stmt in self.indexes.iter() {
272            db.execute(stmt)?;
273        }
274        Ok(())
275    }
276
277    // better to always compile this function
278    #[allow(dead_code)]
279    fn sync<C: ConnectionTrait>(
280        &self,
281        db: &C,
282        existing: &DiscoveredSchema,
283        created_enums: &mut Vec<Statement>,
284    ) -> Result<(), DbErr> {
285        let db_backend = db.get_database_backend();
286
287        // create enum before creating table
288        for stmt in self.enums.iter() {
289            let mut has_enum = false;
290            let new_stmt = db_backend.build(stmt);
291            for exsiting_enum in &existing.enums {
292                if db_backend.build(exsiting_enum) == new_stmt {
293                    has_enum = true;
294                    // TODO add enum variants
295                    break;
296                }
297            }
298            if !has_enum && !created_enums.iter().any(|s| s == &new_stmt) {
299                db.execute(stmt)?;
300                created_enums.push(new_stmt);
301            }
302        }
303        let table_name = get_table_name(self.table.get_table_name());
304        let mut existing_table = None;
305        for tbl in &existing.tables {
306            if get_table_name(tbl.get_table_name()) == table_name {
307                existing_table = Some(tbl);
308                break;
309            }
310        }
311        if let Some(existing_table) = existing_table {
312            for column_def in self.table.get_columns() {
313                let mut column_exists = false;
314                for existing_column in existing_table.get_columns() {
315                    if column_def.get_column_name() == existing_column.get_column_name() {
316                        column_exists = true;
317                        break;
318                    }
319                }
320                if !column_exists {
321                    let mut renamed_from = "";
322                    if let Some(comment) = &column_def.get_column_spec().comment {
323                        if let Some((_, suffix)) = comment.rsplit_once("renamed_from \"") {
324                            if let Some((prefix, _)) = suffix.split_once('"') {
325                                renamed_from = prefix;
326                            }
327                        }
328                    }
329                    if renamed_from.is_empty() {
330                        db.execute(
331                            TableAlterStatement::new()
332                                .table(self.table.get_table_name().expect("Checked above").clone())
333                                .add_column(column_def.to_owned()),
334                        )?;
335                    } else {
336                        db.execute(
337                            TableAlterStatement::new()
338                                .table(self.table.get_table_name().expect("Checked above").clone())
339                                .rename_column(
340                                    renamed_from.to_owned(),
341                                    column_def.get_column_name(),
342                                ),
343                        )?;
344                    }
345                }
346            }
347            if db.get_database_backend() != DbBackend::Sqlite {
348                for foreign_key in self.table.get_foreign_key_create_stmts().iter() {
349                    let mut key_exists = false;
350                    for existing_key in existing_table.get_foreign_key_create_stmts().iter() {
351                        if compare_foreign_key(foreign_key, existing_key) {
352                            key_exists = true;
353                            break;
354                        }
355                    }
356                    if !key_exists {
357                        db.execute(foreign_key)?;
358                    }
359                }
360            }
361        } else {
362            db.execute(&self.table)?;
363        }
364        for stmt in self.indexes.iter() {
365            let mut has_index = false;
366            if let Some(existing_table) = existing_table {
367                for exsiting_index in existing_table.get_indexes() {
368                    if exsiting_index.get_index_spec().get_column_names()
369                        == stmt.get_index_spec().get_column_names()
370                    {
371                        has_index = true;
372                        break;
373                    }
374                }
375            }
376            if !has_index {
377                // shall we do alter table add constraint for unique index?
378                let mut stmt = stmt.clone();
379                stmt.if_not_exists();
380                db.execute(&stmt)?;
381            }
382        }
383        if let Some(existing_table) = existing_table {
384            // find all unique keys from existing table
385            // if it no longer exist in new schema, drop it
386            for exsiting_index in existing_table.get_indexes() {
387                if exsiting_index.is_unique_key() {
388                    let mut has_index = false;
389                    for stmt in self.indexes.iter() {
390                        if exsiting_index.get_index_spec().get_column_names()
391                            == stmt.get_index_spec().get_column_names()
392                        {
393                            has_index = true;
394                            break;
395                        }
396                    }
397                    if !has_index {
398                        if let Some(drop_existing) = exsiting_index.get_index_spec().get_name() {
399                            db.execute(sea_query::Index::drop().name(drop_existing))?;
400                        }
401                    }
402                }
403            }
404        }
405        Ok(())
406    }
407
408    fn debug_print(
409        &self,
410        f: &mut std::fmt::Formatter<'_>,
411        backend: &DbBackend,
412    ) -> std::fmt::Result {
413        write!(f, "EntitySchemaInfo {{")?;
414        write!(f, " table: {:?}", backend.build(&self.table).to_string())?;
415        write!(f, " enums: [")?;
416        for (i, stmt) in self.enums.iter().enumerate() {
417            if i > 0 {
418                write!(f, ", ")?;
419            }
420            write!(f, "{:?}", backend.build(stmt).to_string())?;
421        }
422        write!(f, " ]")?;
423        write!(f, " indexes: [")?;
424        for (i, stmt) in self.indexes.iter().enumerate() {
425            if i > 0 {
426                write!(f, ", ")?;
427            }
428            write!(f, "{:?}", backend.build(stmt).to_string())?;
429        }
430        write!(f, " ]")?;
431        write!(f, " }}")
432    }
433}
434
435fn get_table_name(table_ref: Option<&TableRef>) -> TableName {
436    match table_ref {
437        Some(TableRef::Table(table_name, _)) => table_name.clone(),
438        None => panic!("Expect TableCreateStatement is properly built"),
439        _ => unreachable!("Unexpected {table_ref:?}"),
440    }
441}
442
443fn compare_foreign_key(a: &ForeignKeyCreateStatement, b: &ForeignKeyCreateStatement) -> bool {
444    let a = a.get_foreign_key();
445    let b = b.get_foreign_key();
446
447    a.get_name() == b.get_name()
448        || (a.get_ref_table() == b.get_ref_table()
449            && a.get_columns() == b.get_columns()
450            && a.get_ref_columns() == b.get_ref_columns())
451}