sea_orm/schema/
entity.rs

1use crate::{
2    ActiveEnum, ColumnTrait, ColumnType, DbBackend, EntityTrait, Iterable, PrimaryKeyArity,
3    PrimaryKeyToColumn, PrimaryKeyTrait, RelationTrait, Schema,
4};
5use sea_query::{
6    extension::postgres::{Type, TypeCreateStatement},
7    ColumnDef, Iden, Index, IndexCreateStatement, SeaRc, TableCreateStatement,
8};
9
10impl Schema {
11    /// Creates Postgres enums from an ActiveEnum. See [TypeCreateStatement] for more details
12    pub fn create_enum_from_active_enum<A>(&self) -> TypeCreateStatement
13    where
14        A: ActiveEnum,
15    {
16        create_enum_from_active_enum::<A>(self.backend)
17    }
18
19    /// Creates Postgres enums from an Entity. See [TypeCreateStatement] for more details
20    pub fn create_enum_from_entity<E>(&self, entity: E) -> Vec<TypeCreateStatement>
21    where
22        E: EntityTrait,
23    {
24        create_enum_from_entity(entity, self.backend)
25    }
26
27    /// Creates a table from an Entity. See [TableCreateStatement] for more details.
28    pub fn create_table_from_entity<E>(&self, entity: E) -> TableCreateStatement
29    where
30        E: EntityTrait,
31    {
32        create_table_from_entity(entity, self.backend)
33    }
34
35    /// Creates the indexes from an Entity, returning an empty Vec if there are none
36    /// to create. See [IndexCreateStatement] for more details
37    pub fn create_index_from_entity<E>(&self, entity: E) -> Vec<IndexCreateStatement>
38    where
39        E: EntityTrait,
40    {
41        create_index_from_entity(entity, self.backend)
42    }
43
44    /// Creates a column definition for example to update a table.
45    ///
46    /// ```
47    /// use crate::sea_orm::IdenStatic;
48    /// use sea_orm::{
49    ///     ActiveModelBehavior, ColumnDef, ColumnTrait, ColumnType, DbBackend, EntityName,
50    ///     EntityTrait, EnumIter, PrimaryKeyTrait, RelationDef, RelationTrait, Schema,
51    /// };
52    /// use sea_orm_macros::{DeriveEntityModel, DerivePrimaryKey};
53    /// use sea_query::{MysqlQueryBuilder, TableAlterStatement};
54    ///
55    /// #[derive(Clone, Debug, PartialEq, DeriveEntityModel)]
56    /// #[sea_orm(table_name = "posts")]
57    /// pub struct Model {
58    ///     #[sea_orm(primary_key)]
59    ///     pub id: u32,
60    ///     pub title: String,
61    /// }
62    ///
63    /// #[derive(Copy, Clone, Debug, EnumIter)]
64    /// pub enum Relation {}
65    ///
66    /// impl RelationTrait for Relation {
67    ///     fn def(&self) -> RelationDef {
68    ///         panic!("No RelationDef")
69    ///     }
70    /// }
71    /// impl ActiveModelBehavior for ActiveModel {}
72    ///
73    /// let schema = Schema::new(DbBackend::MySql);
74    ///
75    /// let mut alter_table = TableAlterStatement::new()
76    ///     .table(Entity)
77    ///     .add_column(&mut schema.get_column_def::<Entity>(Column::Title))
78    ///     .take();
79    /// assert_eq!(
80    ///     alter_table.to_string(MysqlQueryBuilder::default()),
81    ///     "ALTER TABLE `posts` ADD COLUMN `title` varchar(255) NOT NULL"
82    /// );
83    /// ```
84    pub fn get_column_def<E>(&self, column: E::Column) -> ColumnDef
85    where
86        E: EntityTrait,
87    {
88        column_def_from_entity_column::<E>(column, self.backend)
89    }
90}
91
92pub(crate) fn create_enum_from_active_enum<A>(backend: DbBackend) -> TypeCreateStatement
93where
94    A: ActiveEnum,
95{
96    if matches!(backend, DbBackend::MySql | DbBackend::Sqlite) {
97        panic!("TypeCreateStatement is not supported in MySQL & SQLite");
98    }
99    let col_def = A::db_type();
100    let col_type = col_def.get_column_type();
101    create_enum_from_column_type(col_type)
102}
103
104pub(crate) fn create_enum_from_column_type(col_type: &ColumnType) -> TypeCreateStatement {
105    let (name, values) = match col_type {
106        ColumnType::Enum { name, variants } => (name.clone(), variants.clone()),
107        _ => panic!("Should be ColumnType::Enum"),
108    };
109    Type::create().as_enum(name).values(values).to_owned()
110}
111
112#[allow(clippy::needless_borrow)]
113pub(crate) fn create_enum_from_entity<E>(_: E, backend: DbBackend) -> Vec<TypeCreateStatement>
114where
115    E: EntityTrait,
116{
117    if matches!(backend, DbBackend::MySql | DbBackend::Sqlite) {
118        return Vec::new();
119    }
120    let mut vec = Vec::new();
121    for col in E::Column::iter() {
122        let col_def = col.def();
123        let col_type = col_def.get_column_type();
124        if !matches!(col_type, ColumnType::Enum { .. }) {
125            continue;
126        }
127        let stmt = create_enum_from_column_type(&col_type);
128        vec.push(stmt);
129    }
130    vec
131}
132
133pub(crate) fn create_index_from_entity<E>(
134    entity: E,
135    _backend: DbBackend,
136) -> Vec<IndexCreateStatement>
137where
138    E: EntityTrait,
139{
140    let mut vec = Vec::new();
141    for column in E::Column::iter() {
142        let column_def = column.def();
143        if !column_def.indexed {
144            continue;
145        }
146        let stmt = Index::create()
147            .name(format!("idx-{}-{}", entity.to_string(), column.to_string()))
148            .table(entity)
149            .col(column)
150            .to_owned();
151        vec.push(stmt)
152    }
153    vec
154}
155
156pub(crate) fn create_table_from_entity<E>(entity: E, backend: DbBackend) -> TableCreateStatement
157where
158    E: EntityTrait,
159{
160    let mut stmt = TableCreateStatement::new();
161
162    if let Some(comment) = entity.comment() {
163        stmt.comment(comment);
164    }
165
166    for column in E::Column::iter() {
167        let mut column_def = column_def_from_entity_column::<E>(column, backend);
168        stmt.col(&mut column_def);
169    }
170
171    if <<E::PrimaryKey as PrimaryKeyTrait>::ValueType as PrimaryKeyArity>::ARITY > 1 {
172        let mut idx_pk = Index::create();
173        for primary_key in E::PrimaryKey::iter() {
174            idx_pk.col(primary_key);
175        }
176        stmt.primary_key(idx_pk.name(format!("pk-{}", entity.to_string())).primary());
177    }
178
179    for relation in E::Relation::iter() {
180        let relation = relation.def();
181        if relation.is_owner {
182            continue;
183        }
184        stmt.foreign_key(&mut relation.into());
185    }
186
187    stmt.table(entity.table_ref()).take()
188}
189
190fn column_def_from_entity_column<E>(column: E::Column, backend: DbBackend) -> ColumnDef
191where
192    E: EntityTrait,
193{
194    let orm_column_def = column.def();
195    let types = match &orm_column_def.col_type {
196        ColumnType::Enum { name, variants } => match backend {
197            DbBackend::MySql => {
198                let variants: Vec<String> = variants.iter().map(|v| v.to_string()).collect();
199                ColumnType::custom(format!("ENUM('{}')", variants.join("', '")).as_str())
200            }
201            DbBackend::Postgres => ColumnType::Custom(SeaRc::clone(name)),
202            DbBackend::Sqlite => orm_column_def.col_type,
203        },
204        _ => orm_column_def.col_type,
205    };
206    let mut column_def = ColumnDef::new_with_type(column, types);
207    if !orm_column_def.null {
208        column_def.not_null();
209    }
210    if orm_column_def.unique {
211        column_def.unique_key();
212    }
213    if let Some(default) = orm_column_def.default {
214        column_def.default(default);
215    }
216    if let Some(comment) = orm_column_def.comment {
217        column_def.comment(comment);
218    }
219    for primary_key in E::PrimaryKey::iter() {
220        if column.to_string() == primary_key.into_column().to_string() {
221            if E::PrimaryKey::auto_increment() {
222                column_def.auto_increment();
223            }
224            if <<E::PrimaryKey as PrimaryKeyTrait>::ValueType as PrimaryKeyArity>::ARITY == 1 {
225                column_def.primary_key();
226            }
227        }
228    }
229    column_def
230}
231
#[cfg(test)]
mod tests {
    use crate::{sea_query::*, tests_cfg::*, DbBackend, EntityName, Schema};
    use pretty_assertions::assert_eq;

    // The generated table statement for `CakeFillingPrice` must match the
    // hand-written one on every backend.
    #[test]
    fn test_create_table_from_entity_table_ref() {
        for backend in [DbBackend::MySql, DbBackend::Postgres, DbBackend::Sqlite] {
            let schema = Schema::new(backend);

            let mut expected = get_cake_filling_price_stmt();
            expected.table(CakeFillingPrice.table_ref());
            let actual = schema.create_table_from_entity(CakeFillingPrice);

            assert_eq!(backend.build(&actual), backend.build(&expected));
        }
    }

    // Hand-written table definition for `CakeFillingPrice`, without the table name.
    fn get_cake_filling_price_stmt() -> TableCreateStatement {
        let mut cake_id = ColumnDef::new(cake_filling_price::Column::CakeId);
        cake_id.integer().not_null();

        let mut filling_id = ColumnDef::new(cake_filling_price::Column::FillingId);
        filling_id.integer().not_null();

        let mut price = ColumnDef::new(cake_filling_price::Column::Price);
        price.decimal().not_null();

        let mut primary_key = Index::create();
        primary_key
            .name("pk-cake_filling_price")
            .col(cake_filling_price::Column::CakeId)
            .col(cake_filling_price::Column::FillingId)
            .primary();

        let mut foreign_key = ForeignKeyCreateStatement::new();
        foreign_key
            .name("fk-cake_filling_price-cake_id-filling_id")
            .from_tbl(CakeFillingPrice)
            .from_col(cake_filling_price::Column::CakeId)
            .from_col(cake_filling_price::Column::FillingId)
            .to_tbl(CakeFilling)
            .to_col(cake_filling::Column::CakeId)
            .to_col(cake_filling::Column::FillingId);

        let mut table = Table::create();
        table
            .col(&mut cake_id)
            .col(&mut filling_id)
            .col(&mut price)
            .primary_key(&mut primary_key)
            .foreign_key(&mut foreign_key);
        table
    }

    // The `indexes` entity must produce the expected table statement and exactly
    // two single-column index statements on every backend.
    #[test]
    fn test_create_index_from_entity_table_ref() {
        for backend in [DbBackend::MySql, DbBackend::Postgres, DbBackend::Sqlite] {
            let schema = Schema::new(backend);

            let mut expected_table = get_indexes_stmt();
            expected_table.table(indexes::Entity.table_ref());
            let actual_table = schema.create_table_from_entity(indexes::Entity);
            assert_eq!(backend.build(&actual_table), backend.build(&expected_table));

            let stmts = schema.create_index_from_entity(indexes::Entity);
            assert_eq!(stmts.len(), 2);

            let expected_indexes = [
                ("idx-indexes-index1_attr", indexes::Column::Index1Attr),
                ("idx-indexes-index2_attr", indexes::Column::Index2Attr),
            ];
            for (actual, (name, column)) in stmts.iter().zip(expected_indexes) {
                let idx: IndexCreateStatement = Index::create()
                    .name(name)
                    .table(indexes::Entity)
                    .col(column)
                    .to_owned();
                assert_eq!(backend.build(actual), backend.build(&idx));
            }
        }
    }

    // Hand-written table definition for the `indexes` entity, without the table name.
    fn get_indexes_stmt() -> TableCreateStatement {
        let mut indexes_id = ColumnDef::new(indexes::Column::IndexesId);
        indexes_id
            .integer()
            .not_null()
            .auto_increment()
            .primary_key();

        let mut unique_attr = ColumnDef::new(indexes::Column::UniqueAttr);
        unique_attr.integer().not_null().unique_key();

        let mut index1_attr = ColumnDef::new(indexes::Column::Index1Attr);
        index1_attr.integer().not_null();

        let mut index2_attr = ColumnDef::new(indexes::Column::Index2Attr);
        index2_attr.integer().not_null().unique_key();

        let mut table = Table::create();
        table
            .col(&mut indexes_id)
            .col(&mut unique_attr)
            .col(&mut index1_attr)
            .col(&mut index2_attr);
        table
    }
}