// sea_orm/schema/entity.rs

1use crate::{
2    ActiveEnum, ColumnTrait, ColumnType, DbBackend, EntityTrait, IdenStatic, Iterable,
3    PrimaryKeyArity, PrimaryKeyToColumn, PrimaryKeyTrait, RelationTrait, Schema,
4};
5use sea_query::{
6    ColumnDef, DynIden, Iden, Index, IndexCreateStatement, SeaRc, TableCreateStatement,
7    extension::postgres::{Type, TypeCreateStatement},
8};
9use std::collections::BTreeMap;
10
11impl Schema {
12    /// Creates Postgres enums from an ActiveEnum. See [`TypeCreateStatement`] for more details.
13    /// Returns None if not Postgres.
14    pub fn create_enum_from_active_enum<A>(&self) -> Option<TypeCreateStatement>
15    where
16        A: ActiveEnum,
17    {
18        create_enum_from_active_enum::<A>(self.backend)
19    }
20
21    /// Creates Postgres enums from an Entity. See [`TypeCreateStatement`] for more details.
22    /// Returns empty vec if not Postgres.
23    pub fn create_enum_from_entity<E>(&self, entity: E) -> Vec<TypeCreateStatement>
24    where
25        E: EntityTrait,
26    {
27        create_enum_from_entity(entity, self.backend)
28    }
29
30    /// Creates a table from an Entity. See [TableCreateStatement] for more details.
31    pub fn create_table_from_entity<E>(&self, entity: E) -> TableCreateStatement
32    where
33        E: EntityTrait,
34    {
35        create_table_from_entity(entity, self.backend)
36    }
37
38    #[doc(hidden)]
39    pub fn create_table_with_index_from_entity<E>(&self, entity: E) -> TableCreateStatement
40    where
41        E: EntityTrait,
42    {
43        let mut table = create_table_from_entity(entity, self.backend);
44        for mut index in create_index_from_entity(entity, self.backend) {
45            table.index(&mut index);
46        }
47        table
48    }
49
50    /// Creates the indexes from an Entity, returning an empty Vec if there are none
51    /// to create. See [IndexCreateStatement] for more details
52    pub fn create_index_from_entity<E>(&self, entity: E) -> Vec<IndexCreateStatement>
53    where
54        E: EntityTrait,
55    {
56        create_index_from_entity(entity, self.backend)
57    }
58
59    /// Creates a column definition for example to update a table.
60    ///
61    /// ```
62    /// use sea_orm::sea_query::TableAlterStatement;
63    /// use sea_orm::{DbBackend, Schema, Statement};
64    ///
65    /// mod post {
66    ///     use sea_orm::entity::prelude::*;
67    ///
68    ///     #[sea_orm::model]
69    ///     #[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
70    ///     #[sea_orm(table_name = "posts")]
71    ///     pub struct Model {
72    ///         #[sea_orm(primary_key)]
73    ///         pub id: u32,
74    ///         pub title: String,
75    ///     }
76    ///
77    ///     impl ActiveModelBehavior for ActiveModel {}
78    /// }
79    ///
80    /// let schema = Schema::new(DbBackend::MySql);
81    ///
82    /// let alter_table: Statement = DbBackend::MySql.build(
83    ///     TableAlterStatement::new()
84    ///         .table(post::Entity)
85    ///         .add_column(&mut schema.get_column_def::<post::Entity>(post::Column::Title)),
86    /// );
87    /// assert_eq!(
88    ///     alter_table.to_string(),
89    ///     "ALTER TABLE `posts` ADD COLUMN `title` varchar(255) NOT NULL"
90    /// );
91    /// ```
92    pub fn get_column_def<E>(&self, column: E::Column) -> ColumnDef
93    where
94        E: EntityTrait,
95    {
96        column_def_from_entity_column::<E>(column, self.backend)
97    }
98}
99
100pub(crate) fn create_enum_from_active_enum<A>(backend: DbBackend) -> Option<TypeCreateStatement>
101where
102    A: ActiveEnum,
103{
104    if matches!(backend, DbBackend::MySql | DbBackend::Sqlite) {
105        return None;
106    }
107    let col_def = A::db_type();
108    let col_type = col_def.get_column_type();
109    create_enum_from_column_type(col_type)
110}
111
112pub(crate) fn create_enum_from_column_type(col_type: &ColumnType) -> Option<TypeCreateStatement> {
113    let (name, values) = match col_type {
114        ColumnType::Enum { name, variants } => (name.clone(), variants.clone()),
115        _ => return None,
116    };
117    Some(Type::create().as_enum(name).values(values).to_owned())
118}
119
120#[allow(clippy::needless_borrow)]
121pub(crate) fn create_enum_from_entity<E>(_: E, backend: DbBackend) -> Vec<TypeCreateStatement>
122where
123    E: EntityTrait,
124{
125    if matches!(backend, DbBackend::MySql | DbBackend::Sqlite) {
126        return Vec::new();
127    }
128    let mut vec = Vec::new();
129    for col in E::Column::iter() {
130        let col_def = col.def();
131        let col_type = col_def.get_column_type();
132        if !matches!(col_type, ColumnType::Enum { .. }) {
133            continue;
134        }
135        if let Some(stmt) = create_enum_from_column_type(&col_type) {
136            vec.push(stmt);
137        }
138    }
139    vec
140}
141
142pub(crate) fn create_index_from_entity<E>(
143    entity: E,
144    _backend: DbBackend,
145) -> Vec<IndexCreateStatement>
146where
147    E: EntityTrait,
148{
149    let mut indexes = Vec::new();
150    let mut unique_keys: BTreeMap<String, Vec<DynIden>> = Default::default();
151
152    for column in E::Column::iter() {
153        let column_def = column.def();
154
155        if column_def.indexed || column_def.unique {
156            let mut stmt = Index::create()
157                .name(format!("idx-{}-{}", entity.to_string(), column.to_string()))
158                .table(entity)
159                .col(column)
160                .take();
161            if column_def.unique {
162                stmt.unique();
163            }
164            indexes.push(stmt);
165        }
166
167        if let Some(key) = column_def.unique_key {
168            unique_keys.entry(key).or_default().push(SeaRc::new(column));
169        }
170    }
171
172    for (key, cols) in unique_keys {
173        let mut stmt = Index::create()
174            .name(format!("idx-{}-{}", entity.to_string(), key))
175            .table(entity)
176            .unique()
177            .take();
178        for col in cols {
179            stmt.col(col);
180        }
181        indexes.push(stmt);
182    }
183
184    indexes
185}
186
187pub(crate) fn create_table_from_entity<E>(entity: E, backend: DbBackend) -> TableCreateStatement
188where
189    E: EntityTrait,
190{
191    let mut stmt = TableCreateStatement::new();
192
193    if let Some(comment) = entity.comment() {
194        stmt.comment(comment);
195    }
196
197    for column in E::Column::iter() {
198        let mut column_def = column_def_from_entity_column::<E>(column, backend);
199        stmt.col(&mut column_def);
200    }
201
202    if <<E::PrimaryKey as PrimaryKeyTrait>::ValueType as PrimaryKeyArity>::ARITY > 1 {
203        let mut idx_pk = Index::create();
204        for primary_key in E::PrimaryKey::iter() {
205            idx_pk.col(primary_key);
206        }
207        stmt.primary_key(idx_pk.name(format!("pk-{}", entity.to_string())).primary());
208    }
209
210    for relation in E::Relation::iter() {
211        let relation = relation.def();
212        if relation.is_owner || relation.skip_fk {
213            continue;
214        }
215        stmt.foreign_key(&mut relation.into());
216    }
217
218    stmt.table(entity.table_ref()).take()
219}
220
221fn column_def_from_entity_column<E>(column: E::Column, backend: DbBackend) -> ColumnDef
222where
223    E: EntityTrait,
224{
225    let orm_column_def = column.def();
226    let types = match &orm_column_def.col_type {
227        ColumnType::Enum { name, variants } => match backend {
228            DbBackend::MySql => {
229                let variants: Vec<String> = variants.iter().map(|v| v.to_string()).collect();
230                ColumnType::custom(format!("ENUM('{}')", variants.join("', '")))
231            }
232            DbBackend::Postgres => ColumnType::Custom(name.clone()),
233            DbBackend::Sqlite => orm_column_def.col_type,
234        },
235        _ => orm_column_def.col_type,
236    };
237    let mut column_def = ColumnDef::new_with_type(column, types);
238    if !orm_column_def.null {
239        column_def.not_null();
240    }
241    if orm_column_def.unique {
242        column_def.unique_key();
243    }
244    if let Some(default) = orm_column_def.default {
245        column_def.default(default);
246    }
247    if let Some(comment) = &orm_column_def.comment {
248        column_def.comment(comment);
249    }
250    if let Some(extra) = &orm_column_def.extra {
251        column_def.extra(extra);
252    }
253    match (&orm_column_def.renamed_from, &orm_column_def.comment) {
254        (Some(renamed_from), Some(comment)) => {
255            column_def.comment(format!("{comment}; renamed_from \"{renamed_from}\""));
256        }
257        (Some(renamed_from), None) => {
258            column_def.comment(format!("renamed_from \"{renamed_from}\""));
259        }
260        (None, _) => {}
261    }
262    for primary_key in E::PrimaryKey::iter() {
263        if column.as_str() == primary_key.into_column().as_str() {
264            if E::PrimaryKey::auto_increment() {
265                column_def.auto_increment();
266            }
267            if <<E::PrimaryKey as PrimaryKeyTrait>::ValueType as PrimaryKeyArity>::ARITY == 1 {
268                column_def.primary_key();
269            }
270        }
271    }
272    column_def
273}
274
#[cfg(test)]
mod tests {
    use crate::{DbBackend, EntityName, Schema, sea_query::*, tests_cfg::*};
    use pretty_assertions::assert_eq;

    /// Every backend the schema helpers are checked against.
    const BACKENDS: [DbBackend; 3] = [DbBackend::MySql, DbBackend::Postgres, DbBackend::Sqlite];

    #[test]
    fn test_create_table_from_entity_table_ref() {
        for backend in BACKENDS {
            let schema = Schema::new(backend);
            let mut expected = get_cake_filling_price_stmt();
            expected.table(CakeFillingPrice.table_ref());

            assert_eq!(
                backend.build(&schema.create_table_from_entity(CakeFillingPrice)),
                backend.build(&expected)
            );
        }
    }

    /// Expected `CREATE TABLE` body for `cake_filling_price`, without the table name. It has
    /// a composite primary key and a composite foreign key to `cake_filling`.
    fn get_cake_filling_price_stmt() -> TableCreateStatement {
        let mut stmt = Table::create();
        stmt.col(
            ColumnDef::new(cake_filling_price::Column::CakeId)
                .integer()
                .not_null(),
        );
        stmt.col(
            ColumnDef::new(cake_filling_price::Column::FillingId)
                .integer()
                .not_null(),
        );
        stmt.col(
            ColumnDef::new(cake_filling_price::Column::Price)
                .decimal()
                .not_null()
                .extra("CHECK (price > 0)"),
        );
        stmt.primary_key(
            Index::create()
                .name("pk-cake_filling_price")
                .col(cake_filling_price::Column::CakeId)
                .col(cake_filling_price::Column::FillingId)
                .primary(),
        );
        stmt.foreign_key(
            ForeignKeyCreateStatement::new()
                .name("fk-cake_filling_price-cake_id-filling_id")
                .from_tbl(CakeFillingPrice)
                .from_col(cake_filling_price::Column::CakeId)
                .from_col(cake_filling_price::Column::FillingId)
                .to_tbl(CakeFilling)
                .to_col(cake_filling::Column::CakeId)
                .to_col(cake_filling::Column::FillingId),
        );
        stmt
    }

    #[test]
    fn test_create_index_from_entity_table_ref() {
        for backend in BACKENDS {
            let schema = Schema::new(backend);

            let mut expected_table = get_indexes_table_stmt();
            expected_table.table(indexes::Entity.table_ref());
            assert_eq!(
                backend.build(&schema.create_table_from_entity(indexes::Entity)),
                backend.build(&expected_table)
            );

            // Per-column indexes come first (column order), then the named composite key.
            let expected_indexes: Vec<IndexCreateStatement> = vec![
                Index::create()
                    .name("idx-indexes-unique_attr")
                    .table(indexes::Entity)
                    .col(indexes::Column::UniqueAttr)
                    .unique()
                    .take(),
                Index::create()
                    .name("idx-indexes-index1_attr")
                    .table(indexes::Entity)
                    .col(indexes::Column::Index1Attr)
                    .take(),
                Index::create()
                    .name("idx-indexes-index2_attr")
                    .table(indexes::Entity)
                    .col(indexes::Column::Index2Attr)
                    .unique()
                    .take(),
                Index::create()
                    .name("idx-indexes-my_unique")
                    .table(indexes::Entity)
                    .col(indexes::Column::UniqueKeyA)
                    .col(indexes::Column::UniqueKeyB)
                    .unique()
                    .take(),
            ];

            let stmts = schema.create_index_from_entity(indexes::Entity);
            assert_eq!(stmts.len(), expected_indexes.len());
            for (actual, expected) in stmts.iter().zip(&expected_indexes) {
                assert_eq!(backend.build(actual), backend.build(expected));
            }
        }
    }

    /// Expected `CREATE TABLE` body for the `indexes` test entity, without the table name.
    fn get_indexes_table_stmt() -> TableCreateStatement {
        let mut stmt = Table::create();
        stmt.col(
            ColumnDef::new(indexes::Column::IndexesId)
                .integer()
                .not_null()
                .auto_increment()
                .primary_key(),
        );
        stmt.col(
            ColumnDef::new(indexes::Column::UniqueAttr)
                .integer()
                .not_null()
                .unique_key(),
        );
        stmt.col(
            ColumnDef::new(indexes::Column::Index1Attr)
                .integer()
                .not_null(),
        );
        stmt.col(
            ColumnDef::new(indexes::Column::Index2Attr)
                .integer()
                .not_null()
                .unique_key(),
        );
        stmt.col(
            ColumnDef::new(indexes::Column::UniqueKeyA)
                .string()
                .not_null(),
        );
        stmt.col(
            ColumnDef::new(indexes::Column::UniqueKeyB)
                .string()
                .not_null(),
        );
        stmt
    }
}