1use crate::{
2 ActiveEnum, ColumnTrait, ColumnType, DbBackend, EntityTrait, IdenStatic, Iterable,
3 PrimaryKeyArity, PrimaryKeyToColumn, PrimaryKeyTrait, RelationTrait, Schema,
4};
5use sea_query::{
6 ColumnDef, DynIden, Iden, Index, IndexCreateStatement, SeaRc, TableCreateStatement,
7 extension::postgres::{Type, TypeCreateStatement},
8};
9use std::collections::BTreeMap;
10
11impl Schema {
12 pub fn create_enum_from_active_enum<A>(&self) -> Option<TypeCreateStatement>
15 where
16 A: ActiveEnum,
17 {
18 create_enum_from_active_enum::<A>(self.backend)
19 }
20
21 pub fn create_enum_from_entity<E>(&self, entity: E) -> Vec<TypeCreateStatement>
24 where
25 E: EntityTrait,
26 {
27 create_enum_from_entity(entity, self.backend)
28 }
29
30 pub fn create_table_from_entity<E>(&self, entity: E) -> TableCreateStatement
32 where
33 E: EntityTrait,
34 {
35 create_table_from_entity(entity, self.backend)
36 }
37
38 #[doc(hidden)]
39 pub fn create_table_with_index_from_entity<E>(&self, entity: E) -> TableCreateStatement
40 where
41 E: EntityTrait,
42 {
43 let mut table = create_table_from_entity(entity, self.backend);
44 for mut index in create_index_from_entity(entity, self.backend) {
45 table.index(&mut index);
46 }
47 table
48 }
49
50 pub fn create_index_from_entity<E>(&self, entity: E) -> Vec<IndexCreateStatement>
53 where
54 E: EntityTrait,
55 {
56 create_index_from_entity(entity, self.backend)
57 }
58
59 pub fn get_column_def<E>(&self, column: E::Column) -> ColumnDef
93 where
94 E: EntityTrait,
95 {
96 column_def_from_entity_column::<E>(column, self.backend)
97 }
98}
99
100pub(crate) fn create_enum_from_active_enum<A>(backend: DbBackend) -> Option<TypeCreateStatement>
101where
102 A: ActiveEnum,
103{
104 if matches!(backend, DbBackend::MySql | DbBackend::Sqlite) {
105 return None;
106 }
107 let col_def = A::db_type();
108 let col_type = col_def.get_column_type();
109 create_enum_from_column_type(col_type)
110}
111
112pub(crate) fn create_enum_from_column_type(col_type: &ColumnType) -> Option<TypeCreateStatement> {
113 let (name, values) = match col_type {
114 ColumnType::Enum { name, variants } => (name.clone(), variants.clone()),
115 _ => return None,
116 };
117 Some(Type::create().as_enum(name).values(values).to_owned())
118}
119
120#[allow(clippy::needless_borrow)]
121pub(crate) fn create_enum_from_entity<E>(_: E, backend: DbBackend) -> Vec<TypeCreateStatement>
122where
123 E: EntityTrait,
124{
125 if matches!(backend, DbBackend::MySql | DbBackend::Sqlite) {
126 return Vec::new();
127 }
128 let mut vec = Vec::new();
129 for col in E::Column::iter() {
130 let col_def = col.def();
131 let col_type = col_def.get_column_type();
132 if !matches!(col_type, ColumnType::Enum { .. }) {
133 continue;
134 }
135 if let Some(stmt) = create_enum_from_column_type(&col_type) {
136 vec.push(stmt);
137 }
138 }
139 vec
140}
141
142pub(crate) fn create_index_from_entity<E>(
143 entity: E,
144 _backend: DbBackend,
145) -> Vec<IndexCreateStatement>
146where
147 E: EntityTrait,
148{
149 let mut indexes = Vec::new();
150 let mut unique_keys: BTreeMap<String, Vec<DynIden>> = Default::default();
151
152 for column in E::Column::iter() {
153 let column_def = column.def();
154
155 if column_def.indexed && !column_def.unique {
156 let stmt = Index::create()
157 .name(format!("idx-{}-{}", entity.to_string(), column.to_string()))
158 .table(entity)
159 .col(column)
160 .take();
161 indexes.push(stmt);
162 }
163
164 if let Some(key) = column_def.unique_key {
165 unique_keys.entry(key).or_default().push(SeaRc::new(column));
166 }
167 }
168
169 for (key, cols) in unique_keys {
170 let mut stmt = Index::create()
171 .name(format!("idx-{}-{}", entity.to_string(), key))
172 .table(entity)
173 .unique()
174 .take();
175 for col in cols {
176 stmt.col(col);
177 }
178 indexes.push(stmt);
179 }
180
181 indexes
182}
183
184pub(crate) fn create_table_from_entity<E>(entity: E, backend: DbBackend) -> TableCreateStatement
185where
186 E: EntityTrait,
187{
188 let mut stmt = TableCreateStatement::new();
189
190 if let Some(comment) = entity.comment() {
191 stmt.comment(comment);
192 }
193
194 for column in E::Column::iter() {
195 let mut column_def = column_def_from_entity_column::<E>(column, backend);
196 stmt.col(&mut column_def);
197 }
198
199 if <<E::PrimaryKey as PrimaryKeyTrait>::ValueType as PrimaryKeyArity>::ARITY > 1 {
200 let mut idx_pk = Index::create();
201 for primary_key in E::PrimaryKey::iter() {
202 idx_pk.col(primary_key);
203 }
204 stmt.primary_key(idx_pk.name(format!("pk-{}", entity.to_string())).primary());
205 }
206
207 for relation in E::Relation::iter() {
208 let relation = relation.def();
209 if relation.is_owner || relation.skip_fk {
210 continue;
211 }
212 stmt.foreign_key(&mut relation.into());
213 }
214
215 stmt.table(entity.table_ref()).take()
216}
217
218fn column_def_from_entity_column<E>(column: E::Column, backend: DbBackend) -> ColumnDef
219where
220 E: EntityTrait,
221{
222 let orm_column_def = column.def();
223 let types = match &orm_column_def.col_type {
224 ColumnType::Enum { name, variants } => match backend {
225 DbBackend::MySql => {
226 let variants: Vec<String> = variants.iter().map(|v| v.to_string()).collect();
227 ColumnType::custom(format!("ENUM('{}')", variants.join("', '")))
228 }
229 DbBackend::Postgres => ColumnType::Custom(name.clone()),
230 DbBackend::Sqlite => orm_column_def.col_type,
231 },
232 _ => orm_column_def.col_type,
233 };
234 let mut column_def = ColumnDef::new_with_type(column, types);
235 if !orm_column_def.null {
236 column_def.not_null();
237 }
238 if orm_column_def.unique {
239 column_def.unique_key();
240 }
241 if let Some(default) = orm_column_def.default {
242 column_def.default(default);
243 }
244 if let Some(comment) = &orm_column_def.comment {
245 column_def.comment(comment);
246 }
247 if let Some(extra) = &orm_column_def.extra {
248 column_def.extra(extra);
249 }
250 match (&orm_column_def.renamed_from, &orm_column_def.comment) {
251 (Some(renamed_from), Some(comment)) => {
252 column_def.comment(format!("{comment}; renamed_from \"{renamed_from}\""));
253 }
254 (Some(renamed_from), None) => {
255 column_def.comment(format!("renamed_from \"{renamed_from}\""));
256 }
257 (None, _) => {}
258 }
259 for primary_key in E::PrimaryKey::iter() {
260 if column.as_str() == primary_key.into_column().as_str() {
261 if E::PrimaryKey::auto_increment() {
262 column_def.auto_increment();
263 }
264 if <<E::PrimaryKey as PrimaryKeyTrait>::ValueType as PrimaryKeyArity>::ARITY == 1 {
265 column_def.primary_key();
266 }
267 }
268 }
269 column_def
270}
271
272#[cfg(test)]
273mod tests {
274 use crate::{DbBackend, EntityName, Schema, sea_query::*, tests_cfg::*};
275 use pretty_assertions::assert_eq;
276
277 #[test]
278 fn test_create_table_from_entity_table_ref() {
279 for builder in [DbBackend::MySql, DbBackend::Postgres, DbBackend::Sqlite] {
280 let schema = Schema::new(builder);
281 assert_eq!(
282 builder.build(&schema.create_table_from_entity(CakeFillingPrice)),
283 builder.build(
284 &get_cake_filling_price_stmt()
285 .table(CakeFillingPrice.table_ref())
286 .to_owned()
287 )
288 );
289 }
290 }
291
292 fn get_cake_filling_price_stmt() -> TableCreateStatement {
293 Table::create()
294 .col(
295 ColumnDef::new(cake_filling_price::Column::CakeId)
296 .integer()
297 .not_null(),
298 )
299 .col(
300 ColumnDef::new(cake_filling_price::Column::FillingId)
301 .integer()
302 .not_null(),
303 )
304 .col(
305 ColumnDef::new(cake_filling_price::Column::Price)
306 .decimal()
307 .not_null()
308 .extra("CHECK (price > 0)"),
309 )
310 .primary_key(
311 Index::create()
312 .name("pk-cake_filling_price")
313 .col(cake_filling_price::Column::CakeId)
314 .col(cake_filling_price::Column::FillingId)
315 .primary(),
316 )
317 .foreign_key(
318 ForeignKeyCreateStatement::new()
319 .name("fk-cake_filling_price-cake_id-filling_id")
320 .from_tbl(CakeFillingPrice)
321 .from_col(cake_filling_price::Column::CakeId)
322 .from_col(cake_filling_price::Column::FillingId)
323 .to_tbl(CakeFilling)
324 .to_col(cake_filling::Column::CakeId)
325 .to_col(cake_filling::Column::FillingId),
326 )
327 .to_owned()
328 }
329
330 #[test]
331 fn test_create_index_from_entity_table_ref() {
332 for builder in [DbBackend::MySql, DbBackend::Postgres, DbBackend::Sqlite] {
333 let schema = Schema::new(builder);
334
335 assert_eq!(
336 builder.build(&schema.create_table_from_entity(indexes::Entity)),
337 builder.build(
338 &get_indexes_table_stmt()
339 .table(indexes::Entity.table_ref())
340 .to_owned()
341 )
342 );
343
344 let stmts = schema.create_index_from_entity(indexes::Entity);
345 assert_eq!(stmts.len(), 2);
346
347 let idx: IndexCreateStatement = Index::create()
348 .name("idx-indexes-index1_attr")
349 .table(indexes::Entity)
350 .col(indexes::Column::Index1Attr)
351 .to_owned();
352 assert_eq!(builder.build(&stmts[0]), builder.build(&idx));
353
354 let idx: IndexCreateStatement = Index::create()
355 .name("idx-indexes-my_unique")
356 .table(indexes::Entity)
357 .col(indexes::Column::UniqueKeyA)
358 .col(indexes::Column::UniqueKeyB)
359 .unique()
360 .take();
361 assert_eq!(builder.build(&stmts[1]), builder.build(&idx));
362 }
363 }
364
365 fn get_indexes_table_stmt() -> TableCreateStatement {
366 Table::create()
367 .col(
368 ColumnDef::new(indexes::Column::IndexesId)
369 .integer()
370 .not_null()
371 .auto_increment()
372 .primary_key(),
373 )
374 .col(
375 ColumnDef::new(indexes::Column::UniqueAttr)
376 .integer()
377 .not_null()
378 .unique_key(),
379 )
380 .col(
381 ColumnDef::new(indexes::Column::Index1Attr)
382 .integer()
383 .not_null(),
384 )
385 .col(
386 ColumnDef::new(indexes::Column::Index2Attr)
387 .integer()
388 .not_null()
389 .unique_key(),
390 )
391 .col(
392 ColumnDef::new(indexes::Column::UniqueKeyA)
393 .string()
394 .not_null(),
395 )
396 .col(
397 ColumnDef::new(indexes::Column::UniqueKeyB)
398 .string()
399 .not_null(),
400 )
401 .to_owned()
402 }
403}