Skip to main content

sea_orm/schema/
entity.rs

1use crate::{
2    ActiveEnum, ColumnTrait, ColumnType, DbBackend, EntityTrait, IdenStatic, Iterable,
3    PrimaryKeyArity, PrimaryKeyToColumn, PrimaryKeyTrait, RelationTrait, Schema,
4};
5use sea_query::{
6    ColumnDef, DynIden, Iden, Index, IndexCreateStatement, SeaRc, TableCreateStatement,
7    extension::postgres::{Type, TypeCreateStatement},
8};
9use std::collections::BTreeMap;
10
11impl Schema {
12    /// Creates Postgres enums from an ActiveEnum. See [`TypeCreateStatement`] for more details.
13    /// Returns None if not Postgres.
14    pub fn create_enum_from_active_enum<A>(&self) -> Option<TypeCreateStatement>
15    where
16        A: ActiveEnum,
17    {
18        create_enum_from_active_enum::<A>(self.backend)
19    }
20
21    /// Creates Postgres enums from an Entity. See [`TypeCreateStatement`] for more details.
22    /// Returns empty vec if not Postgres.
23    pub fn create_enum_from_entity<E>(&self, entity: E) -> Vec<TypeCreateStatement>
24    where
25        E: EntityTrait,
26    {
27        create_enum_from_entity(entity, self.backend)
28    }
29
30    /// Creates a table from an Entity. See [TableCreateStatement] for more details.
31    pub fn create_table_from_entity<E>(&self, entity: E) -> TableCreateStatement
32    where
33        E: EntityTrait,
34    {
35        create_table_from_entity(entity, self.backend)
36    }
37
38    #[doc(hidden)]
39    pub fn create_table_with_index_from_entity<E>(&self, entity: E) -> TableCreateStatement
40    where
41        E: EntityTrait,
42    {
43        let mut table = create_table_from_entity(entity, self.backend);
44        for mut index in create_index_from_entity(entity, self.backend) {
45            table.index(&mut index);
46        }
47        table
48    }
49
50    /// Creates the indexes from an Entity, returning an empty Vec if there are none
51    /// to create. See [IndexCreateStatement] for more details
52    pub fn create_index_from_entity<E>(&self, entity: E) -> Vec<IndexCreateStatement>
53    where
54        E: EntityTrait,
55    {
56        create_index_from_entity(entity, self.backend)
57    }
58
59    /// Creates a column definition for example to update a table.
60    ///
61    /// ```
62    /// use sea_orm::sea_query::TableAlterStatement;
63    /// use sea_orm::{DbBackend, Schema, Statement};
64    ///
65    /// mod post {
66    ///     use sea_orm::entity::prelude::*;
67    ///
68    ///     #[sea_orm::model]
69    ///     #[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
70    ///     #[sea_orm(table_name = "posts")]
71    ///     pub struct Model {
72    ///         #[sea_orm(primary_key)]
73    ///         pub id: u32,
74    ///         pub title: String,
75    ///     }
76    ///
77    ///     impl ActiveModelBehavior for ActiveModel {}
78    /// }
79    ///
80    /// let schema = Schema::new(DbBackend::MySql);
81    ///
82    /// let alter_table: Statement = DbBackend::MySql.build(
83    ///     TableAlterStatement::new()
84    ///         .table(post::Entity)
85    ///         .add_column(&mut schema.get_column_def::<post::Entity>(post::Column::Title)),
86    /// );
87    /// assert_eq!(
88    ///     alter_table.to_string(),
89    ///     "ALTER TABLE `posts` ADD COLUMN `title` varchar(255) NOT NULL"
90    /// );
91    /// ```
92    pub fn get_column_def<E>(&self, column: E::Column) -> ColumnDef
93    where
94        E: EntityTrait,
95    {
96        column_def_from_entity_column::<E>(column, self.backend)
97    }
98}
99
100pub(crate) fn create_enum_from_active_enum<A>(backend: DbBackend) -> Option<TypeCreateStatement>
101where
102    A: ActiveEnum,
103{
104    if matches!(backend, DbBackend::MySql | DbBackend::Sqlite) {
105        return None;
106    }
107    let col_def = A::db_type();
108    let col_type = col_def.get_column_type();
109    create_enum_from_column_type(col_type)
110}
111
112pub(crate) fn create_enum_from_column_type(col_type: &ColumnType) -> Option<TypeCreateStatement> {
113    let (name, values) = match col_type {
114        ColumnType::Enum { name, variants } => (name.clone(), variants.clone()),
115        _ => return None,
116    };
117    Some(Type::create().as_enum(name).values(values).to_owned())
118}
119
120#[allow(clippy::needless_borrow)]
121pub(crate) fn create_enum_from_entity<E>(_: E, backend: DbBackend) -> Vec<TypeCreateStatement>
122where
123    E: EntityTrait,
124{
125    if matches!(backend, DbBackend::MySql | DbBackend::Sqlite) {
126        return Vec::new();
127    }
128    let mut vec = Vec::new();
129    for col in E::Column::iter() {
130        let col_def = col.def();
131        let col_type = col_def.get_column_type();
132        if !matches!(col_type, ColumnType::Enum { .. }) {
133            continue;
134        }
135        if let Some(stmt) = create_enum_from_column_type(&col_type) {
136            vec.push(stmt);
137        }
138    }
139    vec
140}
141
142pub(crate) fn create_index_from_entity<E>(
143    entity: E,
144    _backend: DbBackend,
145) -> Vec<IndexCreateStatement>
146where
147    E: EntityTrait,
148{
149    let mut indexes = Vec::new();
150    let mut unique_keys: BTreeMap<String, Vec<DynIden>> = Default::default();
151
152    for column in E::Column::iter() {
153        let column_def = column.def();
154
155        if column_def.indexed && !column_def.unique {
156            let stmt = Index::create()
157                .name(format!("idx-{}-{}", entity.to_string(), column.to_string()))
158                .table(entity)
159                .col(column)
160                .take();
161            indexes.push(stmt);
162        }
163
164        if let Some(key) = column_def.unique_key {
165            unique_keys.entry(key).or_default().push(SeaRc::new(column));
166        }
167    }
168
169    for (key, cols) in unique_keys {
170        let mut stmt = Index::create()
171            .name(format!("idx-{}-{}", entity.to_string(), key))
172            .table(entity)
173            .unique()
174            .take();
175        for col in cols {
176            stmt.col(col);
177        }
178        indexes.push(stmt);
179    }
180
181    indexes
182}
183
184pub(crate) fn create_table_from_entity<E>(entity: E, backend: DbBackend) -> TableCreateStatement
185where
186    E: EntityTrait,
187{
188    let mut stmt = TableCreateStatement::new();
189
190    if let Some(comment) = entity.comment() {
191        stmt.comment(comment);
192    }
193
194    for column in E::Column::iter() {
195        let mut column_def = column_def_from_entity_column::<E>(column, backend);
196        stmt.col(&mut column_def);
197    }
198
199    if <<E::PrimaryKey as PrimaryKeyTrait>::ValueType as PrimaryKeyArity>::ARITY > 1 {
200        let mut idx_pk = Index::create();
201        for primary_key in E::PrimaryKey::iter() {
202            idx_pk.col(primary_key);
203        }
204        stmt.primary_key(idx_pk.name(format!("pk-{}", entity.to_string())).primary());
205    }
206
207    for relation in E::Relation::iter() {
208        let relation = relation.def();
209        if relation.is_owner || relation.skip_fk {
210            continue;
211        }
212        stmt.foreign_key(&mut relation.into());
213    }
214
215    stmt.table(entity.table_ref()).take()
216}
217
218fn column_def_from_entity_column<E>(column: E::Column, backend: DbBackend) -> ColumnDef
219where
220    E: EntityTrait,
221{
222    let orm_column_def = column.def();
223    let types = match &orm_column_def.col_type {
224        ColumnType::Enum { name, variants } => match backend {
225            DbBackend::MySql => {
226                let variants: Vec<String> = variants.iter().map(|v| v.to_string()).collect();
227                ColumnType::custom(format!("ENUM('{}')", variants.join("', '")))
228            }
229            DbBackend::Postgres => ColumnType::Custom(name.clone()),
230            DbBackend::Sqlite => orm_column_def.col_type,
231        },
232        _ => orm_column_def.col_type,
233    };
234    let mut column_def = ColumnDef::new_with_type(column, types);
235    if !orm_column_def.null {
236        column_def.not_null();
237    }
238    if orm_column_def.unique {
239        column_def.unique_key();
240    }
241    if let Some(default) = orm_column_def.default {
242        column_def.default(default);
243    }
244    if let Some(comment) = &orm_column_def.comment {
245        column_def.comment(comment);
246    }
247    if let Some(extra) = &orm_column_def.extra {
248        column_def.extra(extra);
249    }
250    match (&orm_column_def.renamed_from, &orm_column_def.comment) {
251        (Some(renamed_from), Some(comment)) => {
252            column_def.comment(format!("{comment}; renamed_from \"{renamed_from}\""));
253        }
254        (Some(renamed_from), None) => {
255            column_def.comment(format!("renamed_from \"{renamed_from}\""));
256        }
257        (None, _) => {}
258    }
259    for primary_key in E::PrimaryKey::iter() {
260        if column.as_str() == primary_key.into_column().as_str() {
261            if E::PrimaryKey::auto_increment() {
262                column_def.auto_increment();
263            }
264            if <<E::PrimaryKey as PrimaryKeyTrait>::ValueType as PrimaryKeyArity>::ARITY == 1 {
265                column_def.primary_key();
266            }
267        }
268    }
269    column_def
270}
271
#[cfg(test)]
mod tests {
    use crate::{DbBackend, EntityName, Schema, sea_query::*, tests_cfg::*};
    use pretty_assertions::assert_eq;

    /// Every backend the schema helpers are checked against.
    const BACKENDS: [DbBackend; 3] = [DbBackend::MySql, DbBackend::Postgres, DbBackend::Sqlite];

    #[test]
    fn test_create_table_from_entity_table_ref() {
        for backend in BACKENDS {
            let generated = Schema::new(backend).create_table_from_entity(CakeFillingPrice);
            let expected = get_cake_filling_price_stmt()
                .table(CakeFillingPrice.table_ref())
                .to_owned();
            assert_eq!(backend.build(&generated), backend.build(&expected));
        }
    }

    /// Hand-written `CREATE TABLE` for `cake_filling_price`, without the table name.
    fn get_cake_filling_price_stmt() -> TableCreateStatement {
        let mut pk = Index::create();
        pk.name("pk-cake_filling_price")
            .col(cake_filling_price::Column::CakeId)
            .col(cake_filling_price::Column::FillingId)
            .primary();

        let mut fk = ForeignKeyCreateStatement::new();
        fk.name("fk-cake_filling_price-cake_id-filling_id")
            .from_tbl(CakeFillingPrice)
            .from_col(cake_filling_price::Column::CakeId)
            .from_col(cake_filling_price::Column::FillingId)
            .to_tbl(CakeFilling)
            .to_col(cake_filling::Column::CakeId)
            .to_col(cake_filling::Column::FillingId);

        let mut stmt = Table::create();
        stmt.col(
            ColumnDef::new(cake_filling_price::Column::CakeId)
                .integer()
                .not_null(),
        )
        .col(
            ColumnDef::new(cake_filling_price::Column::FillingId)
                .integer()
                .not_null(),
        )
        .col(
            ColumnDef::new(cake_filling_price::Column::Price)
                .decimal()
                .not_null()
                .extra("CHECK (price > 0)"),
        )
        .primary_key(&mut pk)
        .foreign_key(&mut fk);
        stmt
    }

    #[test]
    fn test_create_index_from_entity_table_ref() {
        for backend in BACKENDS {
            let schema = Schema::new(backend);

            let expected_table = get_indexes_table_stmt()
                .table(indexes::Entity.table_ref())
                .to_owned();
            assert_eq!(
                backend.build(&schema.create_table_from_entity(indexes::Entity)),
                backend.build(&expected_table)
            );

            let index_stmts = schema.create_index_from_entity(indexes::Entity);
            assert_eq!(index_stmts.len(), 2);

            let plain_index = Index::create()
                .name("idx-indexes-index1_attr")
                .table(indexes::Entity)
                .col(indexes::Column::Index1Attr)
                .to_owned();
            let composite_unique = Index::create()
                .name("idx-indexes-my_unique")
                .table(indexes::Entity)
                .col(indexes::Column::UniqueKeyA)
                .col(indexes::Column::UniqueKeyB)
                .unique()
                .take();
            let expected: [IndexCreateStatement; 2] = [plain_index, composite_unique];

            for (actual, wanted) in index_stmts.iter().zip(expected.iter()) {
                assert_eq!(backend.build(actual), backend.build(wanted));
            }
        }
    }

    /// Hand-written `CREATE TABLE` for the `indexes` fixture, without the table name.
    fn get_indexes_table_stmt() -> TableCreateStatement {
        let mut stmt = Table::create();
        stmt.col(
            ColumnDef::new(indexes::Column::IndexesId)
                .integer()
                .not_null()
                .auto_increment()
                .primary_key(),
        )
        .col(
            ColumnDef::new(indexes::Column::UniqueAttr)
                .integer()
                .not_null()
                .unique_key(),
        )
        .col(
            ColumnDef::new(indexes::Column::Index1Attr)
                .integer()
                .not_null(),
        )
        .col(
            ColumnDef::new(indexes::Column::Index2Attr)
                .integer()
                .not_null()
                .unique_key(),
        )
        .col(
            ColumnDef::new(indexes::Column::UniqueKeyA)
                .string()
                .not_null(),
        )
        .col(
            ColumnDef::new(indexes::Column::UniqueKeyB)
                .string()
                .not_null(),
        );
        stmt
    }
}