Skip to main content

sea_orm/schema/
entity.rs

1use crate::{
2    ActiveEnum, ColumnTrait, ColumnType, DbBackend, EntityTrait, IdenStatic, Iterable,
3    PrimaryKeyArity, PrimaryKeyToColumn, PrimaryKeyTrait, RelationTrait, Schema,
4};
5use sea_query::{
6    ColumnDef, DynIden, Iden, Index, IndexCreateStatement, SeaRc, TableCreateStatement, TableName,
7    TableRef,
8    extension::postgres::{Type, TypeCreateStatement},
9};
10use std::collections::BTreeMap;
11
12impl Schema {
13    /// Creates Postgres enums from an ActiveEnum. See [`TypeCreateStatement`] for more details.
14    /// Returns None if not Postgres.
15    pub fn create_enum_from_active_enum<A>(&self) -> Option<TypeCreateStatement>
16    where
17        A: ActiveEnum,
18    {
19        create_enum_from_active_enum::<A>(self.backend)
20    }
21
22    /// Creates Postgres enums from an Entity. See [`TypeCreateStatement`] for more details.
23    /// Returns empty vec if not Postgres.
24    pub fn create_enum_from_entity<E>(&self, entity: E) -> Vec<TypeCreateStatement>
25    where
26        E: EntityTrait,
27    {
28        create_enum_from_entity(entity, self.backend)
29    }
30
31    /// Creates a table from an Entity. See [TableCreateStatement] for more details.
32    pub fn create_table_from_entity<E>(&self, entity: E) -> TableCreateStatement
33    where
34        E: EntityTrait,
35    {
36        create_table_from_entity(entity, self.backend)
37    }
38
39    #[doc(hidden)]
40    pub fn create_table_with_index_from_entity<E>(&self, entity: E) -> TableCreateStatement
41    where
42        E: EntityTrait,
43    {
44        let mut table = create_table_from_entity(entity, self.backend);
45        for mut index in create_index_from_entity(entity, self.backend) {
46            table.index(&mut index);
47        }
48        table
49    }
50
51    /// Creates the indexes from an Entity, returning an empty Vec if there are none
52    /// to create. See [IndexCreateStatement] for more details
53    pub fn create_index_from_entity<E>(&self, entity: E) -> Vec<IndexCreateStatement>
54    where
55        E: EntityTrait,
56    {
57        create_index_from_entity(entity, self.backend)
58    }
59
60    /// Creates a column definition for example to update a table.
61    ///
62    /// ```
63    /// use sea_orm::sea_query::TableAlterStatement;
64    /// use sea_orm::{DbBackend, Schema, Statement};
65    ///
66    /// mod post {
67    ///     use sea_orm::entity::prelude::*;
68    ///
69    ///     #[sea_orm::model]
70    ///     #[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
71    ///     #[sea_orm(table_name = "posts")]
72    ///     pub struct Model {
73    ///         #[sea_orm(primary_key)]
74    ///         pub id: u32,
75    ///         pub title: String,
76    ///     }
77    ///
78    ///     impl ActiveModelBehavior for ActiveModel {}
79    /// }
80    ///
81    /// let schema = Schema::new(DbBackend::MySql);
82    ///
83    /// let alter_table: Statement = DbBackend::MySql.build(
84    ///     TableAlterStatement::new()
85    ///         .table(post::Entity)
86    ///         .add_column(&mut schema.get_column_def::<post::Entity>(post::Column::Title)),
87    /// );
88    /// assert_eq!(
89    ///     alter_table.to_string(),
90    ///     "ALTER TABLE `posts` ADD COLUMN `title` varchar(255) NOT NULL"
91    /// );
92    /// ```
93    pub fn get_column_def<E>(&self, column: E::Column) -> ColumnDef
94    where
95        E: EntityTrait,
96    {
97        column_def_from_entity_column::<E>(column, self.backend)
98    }
99}
100
101pub(crate) fn create_enum_from_active_enum<A>(backend: DbBackend) -> Option<TypeCreateStatement>
102where
103    A: ActiveEnum,
104{
105    if matches!(backend, DbBackend::MySql | DbBackend::Sqlite) {
106        return None;
107    }
108    let col_def = A::db_type();
109    let col_type = col_def.get_column_type();
110    create_enum_from_column_type(col_type)
111}
112
113pub(crate) fn create_enum_from_column_type(col_type: &ColumnType) -> Option<TypeCreateStatement> {
114    let (name, values) = match col_type {
115        ColumnType::Enum { name, variants } => (name.clone(), variants.clone()),
116        _ => return None,
117    };
118    Some(Type::create().as_enum(name).values(values).to_owned())
119}
120
121#[allow(clippy::needless_borrow)]
122pub(crate) fn create_enum_from_entity<E>(_: E, backend: DbBackend) -> Vec<TypeCreateStatement>
123where
124    E: EntityTrait,
125{
126    if matches!(backend, DbBackend::MySql | DbBackend::Sqlite) {
127        return Vec::new();
128    }
129    let mut vec = Vec::new();
130    for col in E::Column::iter() {
131        let col_def = col.def();
132        let col_type = col_def.get_column_type();
133        if !matches!(col_type, ColumnType::Enum { .. }) {
134            continue;
135        }
136        if let Some(stmt) = create_enum_from_column_type(&col_type) {
137            vec.push(stmt);
138        }
139    }
140    vec
141}
142
143pub(crate) fn create_index_from_entity<E>(
144    entity: E,
145    backend: DbBackend,
146) -> Vec<IndexCreateStatement>
147where
148    E: EntityTrait,
149{
150    let mut indexes = Vec::new();
151    let mut unique_keys: BTreeMap<String, Vec<DynIden>> = Default::default();
152
153    for column in E::Column::iter() {
154        let column_def = column.def();
155
156        if column_def.indexed && !column_def.unique {
157            let stmt = Index::create()
158                .name(format!("idx-{}-{}", entity.to_string(), column.to_string()))
159                .table(index_table_ref(entity.table_ref(), backend))
160                .col(column)
161                .take();
162            indexes.push(stmt);
163        }
164
165        if let Some(key) = column_def.unique_key {
166            unique_keys.entry(key).or_default().push(SeaRc::new(column));
167        }
168    }
169
170    for (key, cols) in unique_keys {
171        let mut stmt = Index::create()
172            .name(format!("idx-{}-{}", entity.to_string(), key))
173            .table(index_table_ref(entity.table_ref(), backend))
174            .unique()
175            .take();
176        for col in cols {
177            stmt.col(col);
178        }
179        indexes.push(stmt);
180    }
181
182    indexes
183}
184
185/// Build the table reference used for a generated index.
186///
187/// PostgreSQL accepts a schema-qualified index target
188/// (`CREATE INDEX ... ON "schema"."table"`), so a `schema_name` qualifier is
189/// preserved. SeaQuery's MySQL and SQLite index builders accept only a bare
190/// table name and panic on a qualified one; their generated index is implicitly
191/// scoped to the table's database/schema anyway, so the qualifier is stripped.
192pub(crate) fn index_table_ref(table_ref: TableRef, backend: DbBackend) -> TableRef {
193    match backend {
194        DbBackend::Postgres => table_ref,
195        DbBackend::MySql | DbBackend::Sqlite => match table_ref {
196            TableRef::Table(TableName(Some(_), table), alias) => {
197                TableRef::Table(TableName(None, table), alias)
198            }
199            other => other,
200        },
201    }
202}
203
204pub(crate) fn create_table_from_entity<E>(entity: E, backend: DbBackend) -> TableCreateStatement
205where
206    E: EntityTrait,
207{
208    let mut stmt = TableCreateStatement::new();
209
210    if let Some(comment) = entity.comment() {
211        stmt.comment(comment);
212    }
213
214    for column in E::Column::iter() {
215        let mut column_def = column_def_from_entity_column::<E>(column, backend);
216        stmt.col(&mut column_def);
217    }
218
219    if <<E::PrimaryKey as PrimaryKeyTrait>::ValueType as PrimaryKeyArity>::ARITY > 1 {
220        let mut idx_pk = Index::create();
221        for primary_key in E::PrimaryKey::iter() {
222            idx_pk.col(primary_key);
223        }
224        stmt.primary_key(idx_pk.name(format!("pk-{}", entity.to_string())).primary());
225    }
226
227    for relation in E::Relation::iter() {
228        let relation = relation.def();
229        if relation.is_owner || relation.skip_fk {
230            continue;
231        }
232        stmt.foreign_key(&mut relation.into());
233    }
234
235    stmt.table(entity.table_ref()).take()
236}
237
238fn column_def_from_entity_column<E>(column: E::Column, backend: DbBackend) -> ColumnDef
239where
240    E: EntityTrait,
241{
242    let orm_column_def = column.def();
243    let types = match &orm_column_def.col_type {
244        ColumnType::Enum { name, variants } => match backend {
245            DbBackend::MySql => {
246                let variants: Vec<String> = variants.iter().map(|v| v.to_string()).collect();
247                ColumnType::custom(format!("ENUM('{}')", variants.join("', '")))
248            }
249            DbBackend::Postgres => ColumnType::Custom(name.clone()),
250            DbBackend::Sqlite => orm_column_def.col_type,
251        },
252        _ => orm_column_def.col_type,
253    };
254    let mut column_def = ColumnDef::new_with_type(column, types);
255    if !orm_column_def.null {
256        column_def.not_null();
257    }
258    if orm_column_def.unique {
259        column_def.unique_key();
260    }
261    if let Some(default) = orm_column_def.default {
262        column_def.default(default);
263    }
264    if let Some(comment) = &orm_column_def.comment {
265        column_def.comment(comment);
266    }
267    if let Some(extra) = &orm_column_def.extra {
268        column_def.extra(extra);
269    }
270    match (&orm_column_def.renamed_from, &orm_column_def.comment) {
271        (Some(renamed_from), Some(comment)) => {
272            column_def.comment(format!("{comment}; renamed_from \"{renamed_from}\""));
273        }
274        (Some(renamed_from), None) => {
275            column_def.comment(format!("renamed_from \"{renamed_from}\""));
276        }
277        (None, _) => {}
278    }
279    for primary_key in E::PrimaryKey::iter() {
280        if column.as_str() == primary_key.into_column().as_str() {
281            if E::PrimaryKey::auto_increment() {
282                column_def.auto_increment();
283            }
284            if <<E::PrimaryKey as PrimaryKeyTrait>::ValueType as PrimaryKeyArity>::ARITY == 1 {
285                column_def.primary_key();
286            }
287        }
288    }
289    column_def
290}
291
// Unit tests for the schema builders in this module. Each test builds the
// expected SeaQuery statement by hand and compares the SQL rendered for a
// backend against what `Schema` generates.
#[cfg(test)]
mod tests {
    use crate::{DbBackend, EntityName, Schema, sea_query::*, tests_cfg::*};
    use pretty_assertions::assert_eq;

    // Fixture entity placed in a non-default schema (`sys`). It declares one
    // plain index (`email`) and one two-column unique key (`tenant_name`), so
    // `create_index_from_entity` is expected to yield exactly two statements.
    mod custom_schema_indexes {
        use crate as sea_orm;
        use crate::entity::prelude::*;

        #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
        #[sea_orm(schema_name = "sys", table_name = "app_user")]
        pub struct Model {
            #[sea_orm(primary_key)]
            pub id: i32,
            // Plain (non-unique) index -> `idx-app_user-email`.
            #[sea_orm(indexed)]
            pub email: String,
            // `tenant_id` and `name` share the `tenant_name` key and are combined
            // into one unique index -> `idx-app_user-tenant_name`.
            #[sea_orm(unique_key = "tenant_name")]
            pub tenant_id: i32,
            #[sea_orm(unique_key = "tenant_name")]
            pub name: String,
        }

        #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
        pub enum Relation {}

        impl ActiveModelBehavior for ActiveModel {}
    }

    // `create_table_from_entity` must match the hand-written statement for the
    // composite-key `cake_filling_price` entity on every backend.
    #[test]
    fn test_create_table_from_entity_table_ref() {
        for builder in [DbBackend::MySql, DbBackend::Postgres, DbBackend::Sqlite] {
            let schema = Schema::new(builder);
            assert_eq!(
                builder.build(&schema.create_table_from_entity(CakeFillingPrice)),
                builder.build(
                    &get_cake_filling_price_stmt()
                        .table(CakeFillingPrice.table_ref())
                        .to_owned()
                )
            );
        }
    }

    // Expected `CREATE TABLE` body for `cake_filling_price`; the caller attaches
    // the table reference. Covers the column `extra` check, the named composite
    // `pk-...` primary key and the composite foreign key.
    fn get_cake_filling_price_stmt() -> TableCreateStatement {
        Table::create()
            .col(
                ColumnDef::new(cake_filling_price::Column::CakeId)
                    .integer()
                    .not_null(),
            )
            .col(
                ColumnDef::new(cake_filling_price::Column::FillingId)
                    .integer()
                    .not_null(),
            )
            .col(
                ColumnDef::new(cake_filling_price::Column::Price)
                    .decimal()
                    .not_null()
                    .extra("CHECK (price > 0)"),
            )
            .primary_key(
                Index::create()
                    .name("pk-cake_filling_price")
                    .col(cake_filling_price::Column::CakeId)
                    .col(cake_filling_price::Column::FillingId)
                    .primary(),
            )
            .foreign_key(
                ForeignKeyCreateStatement::new()
                    .name("fk-cake_filling_price-cake_id-filling_id")
                    .from_tbl(CakeFillingPrice)
                    .from_col(cake_filling_price::Column::CakeId)
                    .from_col(cake_filling_price::Column::FillingId)
                    .to_tbl(CakeFilling)
                    .to_col(cake_filling::Column::CakeId)
                    .to_col(cake_filling::Column::FillingId),
            )
            .to_owned()
    }

    // For the `indexes` fixture on every backend: checks the `CREATE TABLE`
    // output, then both generated indexes in order (the plain `index1_attr`
    // index first, the composite `my_unique` unique key second).
    #[test]
    fn test_create_index_from_entity_table_ref() {
        for builder in [DbBackend::MySql, DbBackend::Postgres, DbBackend::Sqlite] {
            let schema = Schema::new(builder);

            assert_eq!(
                builder.build(&schema.create_table_from_entity(indexes::Entity)),
                builder.build(
                    &get_indexes_table_stmt()
                        .table(indexes::Entity.table_ref())
                        .to_owned()
                )
            );

            let stmts = schema.create_index_from_entity(indexes::Entity);
            assert_eq!(stmts.len(), 2);

            // Mirrors the per-backend choice made by `index_table_ref`.
            let index_table = match builder {
                DbBackend::Postgres => indexes::Entity.table_ref(),
                DbBackend::MySql | DbBackend::Sqlite => indexes::Entity.into_table_ref(),
            };
            let idx: IndexCreateStatement = Index::create()
                .name("idx-indexes-index1_attr")
                .table(index_table)
                .col(indexes::Column::Index1Attr)
                .to_owned();
            assert_eq!(builder.build(&stmts[0]), builder.build(&idx));

            // Recomputed because the previous `index_table` was moved into `idx`.
            let index_table = match builder {
                DbBackend::Postgres => indexes::Entity.table_ref(),
                DbBackend::MySql | DbBackend::Sqlite => indexes::Entity.into_table_ref(),
            };
            let idx: IndexCreateStatement = Index::create()
                .name("idx-indexes-my_unique")
                .table(index_table)
                .col(indexes::Column::UniqueKeyA)
                .col(indexes::Column::UniqueKeyB)
                .unique()
                .take();
            assert_eq!(builder.build(&stmts[1]), builder.build(&idx));
        }
    }

    // On Postgres the `sys` schema qualifier must be kept in the index target.
    #[test]
    fn test_create_index_from_entity_non_default_schema_table_ref() {
        let builder = DbBackend::Postgres;
        let schema = Schema::new(builder);
        let stmts = schema.create_index_from_entity(custom_schema_indexes::Entity);
        assert_eq!(stmts.len(), 2);

        let idx: IndexCreateStatement = Index::create()
            .name("idx-app_user-email")
            .table(custom_schema_indexes::Entity.table_ref())
            .col(custom_schema_indexes::Column::Email)
            .to_owned();
        assert_eq!(builder.build(&stmts[0]), builder.build(&idx));

        let idx: IndexCreateStatement = Index::create()
            .name("idx-app_user-tenant_name")
            .table(custom_schema_indexes::Entity.table_ref())
            .col(custom_schema_indexes::Column::TenantId)
            .col(custom_schema_indexes::Column::Name)
            .unique()
            .take();
        assert_eq!(builder.build(&stmts[1]), builder.build(&idx));

        // The generated DDL targets the schema-qualified table.
        assert!(builder.build(&stmts[0]).sql.contains(r#""sys"."app_user""#));
    }

    // Regression guard for the SeaQuery MySQL/SQLite index builders, which panic
    // on a schema-qualified table reference. `create_index_from_entity` must
    // strip the `schema_name` qualifier on those backends, so generation neither
    // panics nor emits a qualified target. See `index_table_ref`.
    #[test]
    fn test_create_index_from_entity_non_default_schema_strips_schema_on_mysql_sqlite() {
        for builder in [DbBackend::MySql, DbBackend::Sqlite] {
            let schema = Schema::new(builder);
            // Must not panic for a `schema_name` entity on MySQL/SQLite.
            let stmts = schema.create_index_from_entity(custom_schema_indexes::Entity);
            assert_eq!(stmts.len(), 2);

            for stmt in &stmts {
                let sql = builder.build(stmt).sql;
                assert!(
                    sql.contains("app_user"),
                    "{builder:?} index should target the table: {sql}"
                );
                // No index name or column of the fixture contains "sys", so any
                // occurrence would come from a leaked schema qualifier.
                assert!(
                    !sql.contains("sys"),
                    "{builder:?} index should not be schema-qualified: {sql}"
                );
            }
        }
    }

    // Expected `CREATE TABLE` body for the `indexes` fixture; the caller attaches
    // the table reference. `UniqueAttr` and `Index2Attr` carry an inline
    // `unique_key()`, and neither gets a standalone index in
    // `test_create_index_from_entity_table_ref`.
    fn get_indexes_table_stmt() -> TableCreateStatement {
        Table::create()
            .col(
                ColumnDef::new(indexes::Column::IndexesId)
                    .integer()
                    .not_null()
                    .auto_increment()
                    .primary_key(),
            )
            .col(
                ColumnDef::new(indexes::Column::UniqueAttr)
                    .integer()
                    .not_null()
                    .unique_key(),
            )
            .col(
                ColumnDef::new(indexes::Column::Index1Attr)
                    .integer()
                    .not_null(),
            )
            .col(
                ColumnDef::new(indexes::Column::Index2Attr)
                    .integer()
                    .not_null()
                    .unique_key(),
            )
            .col(
                ColumnDef::new(indexes::Column::UniqueKeyA)
                    .string()
                    .not_null(),
            )
            .col(
                ColumnDef::new(indexes::Column::UniqueKeyB)
                    .string()
                    .not_null(),
            )
            .to_owned()
    }
}