sea_orm/schema/builder.rs

1use super::{Schema, TopologicalSort};
2use crate::{ConnectionTrait, DbBackend, DbErr, EntityTrait, Statement};
3use sea_query::{
4    ForeignKeyCreateStatement, IndexCreateStatement, TableAlterStatement, TableCreateStatement,
5    TableName, TableRef, extension::postgres::TypeCreateStatement,
6};
7
8/// A schema builder that can take a registry of Entities and synchronize it with database.
9pub struct SchemaBuilder {
10    helper: Schema,
11    entities: Vec<EntitySchemaInfo>,
12}
13
14/// Schema info for Entity. Can be used to re-create schema in database.
15pub struct EntitySchemaInfo {
16    table: TableCreateStatement,
17    enums: Vec<TypeCreateStatement>,
18    indexes: Vec<IndexCreateStatement>,
19}
20
21impl std::fmt::Debug for SchemaBuilder {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        write!(f, "SchemaBuilder {{")?;
24        write!(f, " entities: [")?;
25        for (i, entity) in self.entities.iter().enumerate() {
26            if i > 0 {
27                write!(f, ", ")?;
28            }
29            entity.debug_print(f, &self.helper.backend)?;
30        }
31        write!(f, " ]")?;
32        write!(f, " }}")
33    }
34}
35
36impl std::fmt::Debug for EntitySchemaInfo {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        self.debug_print(f, &DbBackend::Sqlite)
39    }
40}
41
42impl SchemaBuilder {
43    /// Creates a new schema builder
44    pub fn new(schema: Schema) -> Self {
45        Self {
46            helper: schema,
47            entities: Default::default(),
48        }
49    }
50
51    /// Register an entity to this schema
52    pub fn register<E: EntityTrait>(mut self, entity: E) -> Self {
53        let entity = EntitySchemaInfo::new(entity, &self.helper);
54        if !self
55            .entities
56            .iter()
57            .any(|e| e.table.get_table_name() == entity.table.get_table_name())
58        {
59            self.entities.push(entity);
60        }
61        self
62    }
63
64    #[cfg(feature = "entity-registry")]
65    pub(crate) fn helper(&self) -> &Schema {
66        &self.helper
67    }
68
69    #[cfg(feature = "entity-registry")]
70    pub(crate) fn register_entity(&mut self, entity: EntitySchemaInfo) {
71        self.entities.push(entity);
72    }
73
74    /// Synchronize the schema with database, will create missing tables, columns, unique keys, and foreign keys.
75    /// This operation is addition only, will not drop any table / columns.
76    #[cfg(feature = "schema-sync")]
77    #[cfg_attr(docsrs, doc(cfg(feature = "schema-sync")))]
78    pub async fn sync<C>(self, db: &C) -> Result<(), DbErr>
79    where
80        C: ConnectionTrait + sea_schema::Connection,
81    {
82        let _existing =
83            match db.get_database_backend() {
84                #[cfg(feature = "sqlx-mysql")]
85                DbBackend::MySql => {
86                    use sea_schema::{mysql::discovery::SchemaDiscovery, probe::SchemaProbe};
87
88                    let current_schema: String = db
89                        .query_one(
90                            sea_query::SelectStatement::new()
91                                .expr(sea_schema::mysql::MySql::get_current_schema()),
92                        )
93                        .await?
94                        .ok_or_else(|| DbErr::RecordNotFound("Can't get current schema".into()))?
95                        .try_get_by_index(0)?;
96                    let schema_discovery = SchemaDiscovery::new_no_exec(&current_schema);
97
98                    let schema = schema_discovery.discover_with(db).await.map_err(|err| {
99                        DbErr::Query(crate::RuntimeErr::Internal(format!("{err:?}")))
100                    })?;
101
102                    DiscoveredSchema {
103                        tables: schema.tables.iter().map(|table| table.write()).collect(),
104                        enums: vec![],
105                    }
106                }
107                #[cfg(feature = "sqlx-postgres")]
108                DbBackend::Postgres => {
109                    use sea_schema::{postgres::discovery::SchemaDiscovery, probe::SchemaProbe};
110
111                    let current_schema: String = db
112                        .query_one(
113                            sea_query::SelectStatement::new()
114                                .expr(sea_schema::postgres::Postgres::get_current_schema()),
115                        )
116                        .await?
117                        .ok_or_else(|| DbErr::RecordNotFound("Can't get current schema".into()))?
118                        .try_get_by_index(0)?;
119                    let schema_discovery = SchemaDiscovery::new_no_exec(&current_schema);
120
121                    let schema = schema_discovery.discover_with(db).await.map_err(|err| {
122                        DbErr::Query(crate::RuntimeErr::Internal(format!("{err:?}")))
123                    })?;
124
125                    DiscoveredSchema {
126                        tables: schema.tables.iter().map(|table| table.write()).collect(),
127                        enums: schema.enums.iter().map(|def| def.write()).collect(),
128                    }
129                }
130                #[cfg(feature = "sqlx-sqlite")]
131                DbBackend::Sqlite => {
132                    use sea_schema::sqlite::{SqliteDiscoveryError, discovery::SchemaDiscovery};
133                    let schema = SchemaDiscovery::discover_with(db)
134                        .await
135                        .map_err(|err| {
136                            DbErr::Query(match err {
137                                SqliteDiscoveryError::SqlxError(err) => {
138                                    crate::RuntimeErr::SqlxError(err.into())
139                                }
140                                _ => crate::RuntimeErr::Internal(format!("{err:?}")),
141                            })
142                        })?
143                        .merge_indexes_into_table();
144                    DiscoveredSchema {
145                        tables: schema.tables.iter().map(|table| table.write()).collect(),
146                        enums: vec![],
147                    }
148                }
149                #[cfg(feature = "rusqlite")]
150                DbBackend::Sqlite => {
151                    use sea_schema::sqlite::{SqliteDiscoveryError, discovery::SchemaDiscovery};
152                    let schema = SchemaDiscovery::discover_with(db)
153                        .map_err(|err| {
154                            DbErr::Query(match err {
155                                SqliteDiscoveryError::RusqliteError(err) => {
156                                    crate::RuntimeErr::Rusqlite(err.into())
157                                }
158                                _ => crate::RuntimeErr::Internal(format!("{err:?}")),
159                            })
160                        })?
161                        .merge_indexes_into_table();
162                    DiscoveredSchema {
163                        tables: schema.tables.iter().map(|table| table.write()).collect(),
164                        enums: vec![],
165                    }
166                }
167                #[allow(unreachable_patterns)]
168                other => {
169                    return Err(DbErr::BackendNotSupported {
170                        db: other.as_str(),
171                        ctx: "SchemaBuilder::sync",
172                    });
173                }
174            };
175
176        #[allow(unreachable_code)]
177        let mut created_enums: Vec<Statement> = Default::default();
178
179        #[allow(unreachable_code)]
180        for table_name in self.sorted_tables() {
181            if let Some(entity) = self
182                .entities
183                .iter()
184                .find(|entity| table_name == get_table_name(entity.table.get_table_name()))
185            {
186                entity.sync(db, &_existing, &mut created_enums).await?;
187            }
188        }
189
190        Ok(())
191    }
192
193    /// Apply this schema to a database, will create all registered tables, columns, unique keys, and foreign keys.
194    /// Will fail if any table already exists. Use [`sync`] if you want an incremental version that can perform schema diff.
195    pub async fn apply<C: ConnectionTrait>(self, db: &C) -> Result<(), DbErr> {
196        let mut created_enums: Vec<Statement> = Default::default();
197
198        for table_name in self.sorted_tables() {
199            if let Some(entity) = self
200                .entities
201                .iter()
202                .find(|entity| table_name == get_table_name(entity.table.get_table_name()))
203            {
204                entity.apply(db, &mut created_enums).await?;
205            }
206        }
207
208        Ok(())
209    }
210
211    fn sorted_tables(&self) -> Vec<TableName> {
212        let mut sorter = TopologicalSort::<TableName>::new();
213
214        for entity in self.entities.iter() {
215            let table_name = get_table_name(entity.table.get_table_name());
216            sorter.insert(table_name);
217        }
218        for entity in self.entities.iter() {
219            let self_table = get_table_name(entity.table.get_table_name());
220            for fk in entity.table.get_foreign_key_create_stmts().iter() {
221                let fk = fk.get_foreign_key();
222                let ref_table = get_table_name(fk.get_ref_table());
223                if self_table != ref_table {
224                    // self cycle is okay
225                    sorter.add_dependency(ref_table, self_table.clone());
226                }
227            }
228        }
229        let mut sorted = Vec::new();
230        while let Some(i) = sorter.pop() {
231            sorted.push(i);
232        }
233        if sorted.len() != self.entities.len() {
234            // push leftover tables
235            for entity in self.entities.iter() {
236                let table_name = get_table_name(entity.table.get_table_name());
237                if !sorted.contains(&table_name) {
238                    sorted.push(table_name);
239                }
240            }
241        }
242
243        sorted
244    }
245}
246
247struct DiscoveredSchema {
248    tables: Vec<TableCreateStatement>,
249    enums: Vec<TypeCreateStatement>,
250}
251
252impl EntitySchemaInfo {
253    /// Creates a EntitySchemaInfo object given a generic Entity.
254    pub fn new<E: EntityTrait>(entity: E, helper: &Schema) -> Self {
255        Self {
256            table: helper.create_table_from_entity(entity),
257            enums: helper.create_enum_from_entity(entity),
258            indexes: helper.create_index_from_entity(entity),
259        }
260    }
261
262    async fn apply<C: ConnectionTrait>(
263        &self,
264        db: &C,
265        created_enums: &mut Vec<Statement>,
266    ) -> Result<(), DbErr> {
267        for stmt in self.enums.iter() {
268            let new_stmt = db.get_database_backend().build(stmt);
269            if !created_enums.iter().any(|s| s == &new_stmt) {
270                db.execute(stmt).await?;
271                created_enums.push(new_stmt);
272            }
273        }
274        db.execute(&self.table).await?;
275        for stmt in self.indexes.iter() {
276            db.execute(stmt).await?;
277        }
278        Ok(())
279    }
280
281    // better to always compile this function
282    #[allow(dead_code)]
283    async fn sync<C: ConnectionTrait>(
284        &self,
285        db: &C,
286        existing: &DiscoveredSchema,
287        created_enums: &mut Vec<Statement>,
288    ) -> Result<(), DbErr> {
289        let db_backend = db.get_database_backend();
290
291        // create enum before creating table
292        for stmt in self.enums.iter() {
293            let mut has_enum = false;
294            let new_stmt = db_backend.build(stmt);
295            for exsiting_enum in &existing.enums {
296                if db_backend.build(exsiting_enum) == new_stmt {
297                    has_enum = true;
298                    // TODO add enum variants
299                    break;
300                }
301            }
302            if !has_enum && !created_enums.iter().any(|s| s == &new_stmt) {
303                db.execute(stmt).await?;
304                created_enums.push(new_stmt);
305            }
306        }
307        let table_name = get_table_name(self.table.get_table_name());
308        let mut existing_table = None;
309        for tbl in &existing.tables {
310            if get_table_name(tbl.get_table_name()) == table_name {
311                existing_table = Some(tbl);
312                break;
313            }
314        }
315        if let Some(existing_table) = existing_table {
316            for column_def in self.table.get_columns() {
317                let mut column_exists = false;
318                for existing_column in existing_table.get_columns() {
319                    if column_def.get_column_name() == existing_column.get_column_name() {
320                        column_exists = true;
321                        break;
322                    }
323                }
324                if !column_exists {
325                    let mut renamed_from = "";
326                    if let Some(comment) = &column_def.get_column_spec().comment {
327                        if let Some((_, suffix)) = comment.rsplit_once("renamed_from \"") {
328                            if let Some((prefix, _)) = suffix.split_once('"') {
329                                renamed_from = prefix;
330                            }
331                        }
332                    }
333                    if renamed_from.is_empty() {
334                        db.execute(
335                            TableAlterStatement::new()
336                                .table(self.table.get_table_name().expect("Checked above").clone())
337                                .add_column(column_def.to_owned()),
338                        )
339                        .await?;
340                    } else {
341                        db.execute(
342                            TableAlterStatement::new()
343                                .table(self.table.get_table_name().expect("Checked above").clone())
344                                .rename_column(
345                                    renamed_from.to_owned(),
346                                    column_def.get_column_name(),
347                                ),
348                        )
349                        .await?;
350                    }
351                }
352            }
353            if db.get_database_backend() != DbBackend::Sqlite {
354                for foreign_key in self.table.get_foreign_key_create_stmts().iter() {
355                    let mut key_exists = false;
356                    for existing_key in existing_table.get_foreign_key_create_stmts().iter() {
357                        if compare_foreign_key(foreign_key, existing_key) {
358                            key_exists = true;
359                            break;
360                        }
361                    }
362                    if !key_exists {
363                        db.execute(foreign_key).await?;
364                    }
365                }
366            }
367        } else {
368            db.execute(&self.table).await?;
369        }
370        for stmt in self.indexes.iter() {
371            let mut has_index = false;
372            if let Some(existing_table) = existing_table {
373                for exsiting_index in existing_table.get_indexes() {
374                    if exsiting_index.get_index_spec().get_column_names()
375                        == stmt.get_index_spec().get_column_names()
376                    {
377                        has_index = true;
378                        break;
379                    }
380                }
381            }
382            if !has_index {
383                // shall we do alter table add constraint for unique index?
384                let mut stmt = stmt.clone();
385                stmt.if_not_exists();
386                db.execute(&stmt).await?;
387            }
388        }
389        if let Some(existing_table) = existing_table {
390            // find all unique keys from existing table
391            // if it no longer exist in new schema, drop it
392            for exsiting_index in existing_table.get_indexes() {
393                if exsiting_index.is_unique_key() {
394                    let mut has_index = false;
395                    for stmt in self.indexes.iter() {
396                        if exsiting_index.get_index_spec().get_column_names()
397                            == stmt.get_index_spec().get_column_names()
398                        {
399                            has_index = true;
400                            break;
401                        }
402                    }
403                    if !has_index {
404                        if let Some(drop_existing) = exsiting_index.get_index_spec().get_name() {
405                            db.execute(sea_query::Index::drop().name(drop_existing))
406                                .await?;
407                        }
408                    }
409                }
410            }
411        }
412        Ok(())
413    }
414
415    fn debug_print(
416        &self,
417        f: &mut std::fmt::Formatter<'_>,
418        backend: &DbBackend,
419    ) -> std::fmt::Result {
420        write!(f, "EntitySchemaInfo {{")?;
421        write!(f, " table: {:?}", backend.build(&self.table).to_string())?;
422        write!(f, " enums: [")?;
423        for (i, stmt) in self.enums.iter().enumerate() {
424            if i > 0 {
425                write!(f, ", ")?;
426            }
427            write!(f, "{:?}", backend.build(stmt).to_string())?;
428        }
429        write!(f, " ]")?;
430        write!(f, " indexes: [")?;
431        for (i, stmt) in self.indexes.iter().enumerate() {
432            if i > 0 {
433                write!(f, ", ")?;
434            }
435            write!(f, "{:?}", backend.build(stmt).to_string())?;
436        }
437        write!(f, " ]")?;
438        write!(f, " }}")
439    }
440}
441
442fn get_table_name(table_ref: Option<&TableRef>) -> TableName {
443    match table_ref {
444        Some(TableRef::Table(table_name, _)) => table_name.clone(),
445        None => panic!("Expect TableCreateStatement is properly built"),
446        _ => unreachable!("Unexpected {table_ref:?}"),
447    }
448}
449
450fn compare_foreign_key(a: &ForeignKeyCreateStatement, b: &ForeignKeyCreateStatement) -> bool {
451    let a = a.get_foreign_key();
452    let b = b.get_foreign_key();
453
454    a.get_name() == b.get_name()
455        || (a.get_ref_table() == b.get_ref_table()
456            && a.get_columns() == b.get_columns()
457            && a.get_ref_columns() == b.get_ref_columns())
458}