1use crate::{
2 ActiveEnum, ColumnTrait, ColumnType, DbBackend, EntityTrait, IdenStatic, Iterable,
3 PrimaryKeyArity, PrimaryKeyToColumn, PrimaryKeyTrait, RelationTrait, Schema,
4};
5use sea_query::{
6 ColumnDef, DynIden, Iden, Index, IndexCreateStatement, SeaRc, TableCreateStatement, TableName,
7 TableRef,
8 extension::postgres::{Type, TypeCreateStatement},
9};
10use std::collections::BTreeMap;
11
12impl Schema {
13 pub fn create_enum_from_active_enum<A>(&self) -> Option<TypeCreateStatement>
16 where
17 A: ActiveEnum,
18 {
19 create_enum_from_active_enum::<A>(self.backend)
20 }
21
22 pub fn create_enum_from_entity<E>(&self, entity: E) -> Vec<TypeCreateStatement>
25 where
26 E: EntityTrait,
27 {
28 create_enum_from_entity(entity, self.backend)
29 }
30
31 pub fn create_table_from_entity<E>(&self, entity: E) -> TableCreateStatement
33 where
34 E: EntityTrait,
35 {
36 create_table_from_entity(entity, self.backend)
37 }
38
39 #[doc(hidden)]
40 pub fn create_table_with_index_from_entity<E>(&self, entity: E) -> TableCreateStatement
41 where
42 E: EntityTrait,
43 {
44 let mut table = create_table_from_entity(entity, self.backend);
45 for mut index in create_index_from_entity(entity, self.backend) {
46 table.index(&mut index);
47 }
48 table
49 }
50
51 pub fn create_index_from_entity<E>(&self, entity: E) -> Vec<IndexCreateStatement>
54 where
55 E: EntityTrait,
56 {
57 create_index_from_entity(entity, self.backend)
58 }
59
60 pub fn get_column_def<E>(&self, column: E::Column) -> ColumnDef
94 where
95 E: EntityTrait,
96 {
97 column_def_from_entity_column::<E>(column, self.backend)
98 }
99}
100
101pub(crate) fn create_enum_from_active_enum<A>(backend: DbBackend) -> Option<TypeCreateStatement>
102where
103 A: ActiveEnum,
104{
105 if matches!(backend, DbBackend::MySql | DbBackend::Sqlite) {
106 return None;
107 }
108 let col_def = A::db_type();
109 let col_type = col_def.get_column_type();
110 create_enum_from_column_type(col_type)
111}
112
113pub(crate) fn create_enum_from_column_type(col_type: &ColumnType) -> Option<TypeCreateStatement> {
114 let (name, values) = match col_type {
115 ColumnType::Enum { name, variants } => (name.clone(), variants.clone()),
116 _ => return None,
117 };
118 Some(Type::create().as_enum(name).values(values).to_owned())
119}
120
121#[allow(clippy::needless_borrow)]
122pub(crate) fn create_enum_from_entity<E>(_: E, backend: DbBackend) -> Vec<TypeCreateStatement>
123where
124 E: EntityTrait,
125{
126 if matches!(backend, DbBackend::MySql | DbBackend::Sqlite) {
127 return Vec::new();
128 }
129 let mut vec = Vec::new();
130 for col in E::Column::iter() {
131 let col_def = col.def();
132 let col_type = col_def.get_column_type();
133 if !matches!(col_type, ColumnType::Enum { .. }) {
134 continue;
135 }
136 if let Some(stmt) = create_enum_from_column_type(&col_type) {
137 vec.push(stmt);
138 }
139 }
140 vec
141}
142
143pub(crate) fn create_index_from_entity<E>(
144 entity: E,
145 backend: DbBackend,
146) -> Vec<IndexCreateStatement>
147where
148 E: EntityTrait,
149{
150 let mut indexes = Vec::new();
151 let mut unique_keys: BTreeMap<String, Vec<DynIden>> = Default::default();
152
153 for column in E::Column::iter() {
154 let column_def = column.def();
155
156 if column_def.indexed && !column_def.unique {
157 let stmt = Index::create()
158 .name(format!("idx-{}-{}", entity.to_string(), column.to_string()))
159 .table(index_table_ref(entity.table_ref(), backend))
160 .col(column)
161 .take();
162 indexes.push(stmt);
163 }
164
165 if let Some(key) = column_def.unique_key {
166 unique_keys.entry(key).or_default().push(SeaRc::new(column));
167 }
168 }
169
170 for (key, cols) in unique_keys {
171 let mut stmt = Index::create()
172 .name(format!("idx-{}-{}", entity.to_string(), key))
173 .table(index_table_ref(entity.table_ref(), backend))
174 .unique()
175 .take();
176 for col in cols {
177 stmt.col(col);
178 }
179 indexes.push(stmt);
180 }
181
182 indexes
183}
184
185pub(crate) fn index_table_ref(table_ref: TableRef, backend: DbBackend) -> TableRef {
193 match backend {
194 DbBackend::Postgres => table_ref,
195 DbBackend::MySql | DbBackend::Sqlite => match table_ref {
196 TableRef::Table(TableName(Some(_), table), alias) => {
197 TableRef::Table(TableName(None, table), alias)
198 }
199 other => other,
200 },
201 }
202}
203
204pub(crate) fn create_table_from_entity<E>(entity: E, backend: DbBackend) -> TableCreateStatement
205where
206 E: EntityTrait,
207{
208 let mut stmt = TableCreateStatement::new();
209
210 if let Some(comment) = entity.comment() {
211 stmt.comment(comment);
212 }
213
214 for column in E::Column::iter() {
215 let mut column_def = column_def_from_entity_column::<E>(column, backend);
216 stmt.col(&mut column_def);
217 }
218
219 if <<E::PrimaryKey as PrimaryKeyTrait>::ValueType as PrimaryKeyArity>::ARITY > 1 {
220 let mut idx_pk = Index::create();
221 for primary_key in E::PrimaryKey::iter() {
222 idx_pk.col(primary_key);
223 }
224 stmt.primary_key(idx_pk.name(format!("pk-{}", entity.to_string())).primary());
225 }
226
227 for relation in E::Relation::iter() {
228 let relation = relation.def();
229 if relation.is_owner || relation.skip_fk {
230 continue;
231 }
232 stmt.foreign_key(&mut relation.into());
233 }
234
235 stmt.table(entity.table_ref()).take()
236}
237
238fn column_def_from_entity_column<E>(column: E::Column, backend: DbBackend) -> ColumnDef
239where
240 E: EntityTrait,
241{
242 let orm_column_def = column.def();
243 let types = match &orm_column_def.col_type {
244 ColumnType::Enum { name, variants } => match backend {
245 DbBackend::MySql => {
246 let variants: Vec<String> = variants.iter().map(|v| v.to_string()).collect();
247 ColumnType::custom(format!("ENUM('{}')", variants.join("', '")))
248 }
249 DbBackend::Postgres => ColumnType::Custom(name.clone()),
250 DbBackend::Sqlite => orm_column_def.col_type,
251 },
252 _ => orm_column_def.col_type,
253 };
254 let mut column_def = ColumnDef::new_with_type(column, types);
255 if !orm_column_def.null {
256 column_def.not_null();
257 }
258 if orm_column_def.unique {
259 column_def.unique_key();
260 }
261 if let Some(default) = orm_column_def.default {
262 column_def.default(default);
263 }
264 if let Some(comment) = &orm_column_def.comment {
265 column_def.comment(comment);
266 }
267 if let Some(extra) = &orm_column_def.extra {
268 column_def.extra(extra);
269 }
270 match (&orm_column_def.renamed_from, &orm_column_def.comment) {
271 (Some(renamed_from), Some(comment)) => {
272 column_def.comment(format!("{comment}; renamed_from \"{renamed_from}\""));
273 }
274 (Some(renamed_from), None) => {
275 column_def.comment(format!("renamed_from \"{renamed_from}\""));
276 }
277 (None, _) => {}
278 }
279 for primary_key in E::PrimaryKey::iter() {
280 if column.as_str() == primary_key.into_column().as_str() {
281 if E::PrimaryKey::auto_increment() {
282 column_def.auto_increment();
283 }
284 if <<E::PrimaryKey as PrimaryKeyTrait>::ValueType as PrimaryKeyArity>::ARITY == 1 {
285 column_def.primary_key();
286 }
287 }
288 }
289 column_def
290}
291
// Unit tests comparing the statements generated from entities with
// hand-written sea-query statements, by building both to SQL.
#[cfg(test)]
mod tests {
    use crate::{DbBackend, EntityName, Schema, sea_query::*, tests_cfg::*};
    use pretty_assertions::assert_eq;

    // Test-only entity living in a non-default schema (`sys.app_user`).
    // It declares one plain indexed column (`email`) and a two-column
    // unique key named `tenant_name` (`tenant_id`, `name`).
    mod custom_schema_indexes {
        use crate as sea_orm;
        use crate::entity::prelude::*;

        #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
        #[sea_orm(schema_name = "sys", table_name = "app_user")]
        pub struct Model {
            #[sea_orm(primary_key)]
            pub id: i32,
            #[sea_orm(indexed)]
            pub email: String,
            #[sea_orm(unique_key = "tenant_name")]
            pub tenant_id: i32,
            #[sea_orm(unique_key = "tenant_name")]
            pub name: String,
        }

        #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
        pub enum Relation {}

        impl ActiveModelBehavior for ActiveModel {}
    }

    // The generated CREATE TABLE for `CakeFillingPrice` must build to the
    // same SQL as the hand-written statement on every backend.
    #[test]
    fn test_create_table_from_entity_table_ref() {
        for builder in [DbBackend::MySql, DbBackend::Postgres, DbBackend::Sqlite] {
            let schema = Schema::new(builder);
            assert_eq!(
                builder.build(&schema.create_table_from_entity(CakeFillingPrice)),
                builder.build(
                    &get_cake_filling_price_stmt()
                        .table(CakeFillingPrice.table_ref())
                        .to_owned()
                )
            );
        }
    }

    // Expected CREATE TABLE for `cake_filling_price`, without the table
    // name (callers set it): three NOT NULL columns, a composite primary
    // key and a composite foreign key to `cake_filling`.
    fn get_cake_filling_price_stmt() -> TableCreateStatement {
        Table::create()
            .col(
                ColumnDef::new(cake_filling_price::Column::CakeId)
                    .integer()
                    .not_null(),
            )
            .col(
                ColumnDef::new(cake_filling_price::Column::FillingId)
                    .integer()
                    .not_null(),
            )
            .col(
                ColumnDef::new(cake_filling_price::Column::Price)
                    .decimal()
                    .not_null()
                    .extra("CHECK (price > 0)"),
            )
            .primary_key(
                Index::create()
                    .name("pk-cake_filling_price")
                    .col(cake_filling_price::Column::CakeId)
                    .col(cake_filling_price::Column::FillingId)
                    .primary(),
            )
            .foreign_key(
                ForeignKeyCreateStatement::new()
                    .name("fk-cake_filling_price-cake_id-filling_id")
                    .from_tbl(CakeFillingPrice)
                    .from_col(cake_filling_price::Column::CakeId)
                    .from_col(cake_filling_price::Column::FillingId)
                    .to_tbl(CakeFilling)
                    .to_col(cake_filling::Column::CakeId)
                    .to_col(cake_filling::Column::FillingId),
            )
            .to_owned()
    }

    // For the `indexes` entity, checks on every backend both the generated
    // CREATE TABLE and the two generated index statements (one plain, one
    // unique), in the order they are returned.
    #[test]
    fn test_create_index_from_entity_table_ref() {
        for builder in [DbBackend::MySql, DbBackend::Postgres, DbBackend::Sqlite] {
            let schema = Schema::new(builder);

            assert_eq!(
                builder.build(&schema.create_table_from_entity(indexes::Entity)),
                builder.build(
                    &get_indexes_table_stmt()
                        .table(indexes::Entity.table_ref())
                        .to_owned()
                )
            );

            let stmts = schema.create_index_from_entity(indexes::Entity);
            assert_eq!(stmts.len(), 2);

            // First statement: the plain index on `index1_attr`.
            let index_table = match builder {
                DbBackend::Postgres => indexes::Entity.table_ref(),
                DbBackend::MySql | DbBackend::Sqlite => indexes::Entity.into_table_ref(),
            };
            let idx: IndexCreateStatement = Index::create()
                .name("idx-indexes-index1_attr")
                .table(index_table)
                .col(indexes::Column::Index1Attr)
                .to_owned();
            assert_eq!(builder.build(&stmts[0]), builder.build(&idx));

            // Second statement: the unique index for the `my_unique` key.
            let index_table = match builder {
                DbBackend::Postgres => indexes::Entity.table_ref(),
                DbBackend::MySql | DbBackend::Sqlite => indexes::Entity.into_table_ref(),
            };
            let idx: IndexCreateStatement = Index::create()
                .name("idx-indexes-my_unique")
                .table(index_table)
                .col(indexes::Column::UniqueKeyA)
                .col(indexes::Column::UniqueKeyB)
                .unique()
                .take();
            assert_eq!(builder.build(&stmts[1]), builder.build(&idx));
        }
    }

    // On Postgres, indexes for an entity in a non-default schema must keep
    // the schema-qualified table reference.
    #[test]
    fn test_create_index_from_entity_non_default_schema_table_ref() {
        let builder = DbBackend::Postgres;
        let schema = Schema::new(builder);
        let stmts = schema.create_index_from_entity(custom_schema_indexes::Entity);
        assert_eq!(stmts.len(), 2);

        let idx: IndexCreateStatement = Index::create()
            .name("idx-app_user-email")
            .table(custom_schema_indexes::Entity.table_ref())
            .col(custom_schema_indexes::Column::Email)
            .to_owned();
        assert_eq!(builder.build(&stmts[0]), builder.build(&idx));

        let idx: IndexCreateStatement = Index::create()
            .name("idx-app_user-tenant_name")
            .table(custom_schema_indexes::Entity.table_ref())
            .col(custom_schema_indexes::Column::TenantId)
            .col(custom_schema_indexes::Column::Name)
            .unique()
            .take();
        assert_eq!(builder.build(&stmts[1]), builder.build(&idx));

        // The built SQL must name the table as "sys"."app_user".
        assert!(builder.build(&stmts[0]).sql.contains(r#""sys"."app_user""#));
    }

    // On MySQL and SQLite, the same entity's indexes must target the bare
    // table name: the SQL mentions `app_user` and never `sys`.
    #[test]
    fn test_create_index_from_entity_non_default_schema_strips_schema_on_mysql_sqlite() {
        for builder in [DbBackend::MySql, DbBackend::Sqlite] {
            let schema = Schema::new(builder);
            let stmts = schema.create_index_from_entity(custom_schema_indexes::Entity);
            assert_eq!(stmts.len(), 2);

            for stmt in &stmts {
                let sql = builder.build(stmt).sql;
                assert!(
                    sql.contains("app_user"),
                    "{builder:?} index should target the table: {sql}"
                );
                assert!(
                    !sql.contains("sys"),
                    "{builder:?} index should not be schema-qualified: {sql}"
                );
            }
        }
    }

    // Expected CREATE TABLE for the `indexes` entity, without the table
    // name (callers set it). Only `unique_attr` and `index2_attr` carry a
    // column-level unique constraint here.
    fn get_indexes_table_stmt() -> TableCreateStatement {
        Table::create()
            .col(
                ColumnDef::new(indexes::Column::IndexesId)
                    .integer()
                    .not_null()
                    .auto_increment()
                    .primary_key(),
            )
            .col(
                ColumnDef::new(indexes::Column::UniqueAttr)
                    .integer()
                    .not_null()
                    .unique_key(),
            )
            .col(
                ColumnDef::new(indexes::Column::Index1Attr)
                    .integer()
                    .not_null(),
            )
            .col(
                ColumnDef::new(indexes::Column::Index2Attr)
                    .integer()
                    .not_null()
                    .unique_key(),
            )
            .col(
                ColumnDef::new(indexes::Column::UniqueKeyA)
                    .string()
                    .not_null(),
            )
            .col(
                ColumnDef::new(indexes::Column::UniqueKeyB)
                    .string()
                    .not_null(),
            )
            .to_owned()
    }
}