1use crate::{
2 ActiveEnum, ColumnTrait, ColumnType, DbBackend, EntityTrait, Iterable, PrimaryKeyArity,
3 PrimaryKeyToColumn, PrimaryKeyTrait, RelationTrait, Schema,
4};
5use sea_query::{
6 extension::postgres::{Type, TypeCreateStatement},
7 ColumnDef, Iden, Index, IndexCreateStatement, SeaRc, TableCreateStatement,
8};
9
10impl Schema {
11 pub fn create_enum_from_active_enum<A>(&self) -> TypeCreateStatement
13 where
14 A: ActiveEnum,
15 {
16 create_enum_from_active_enum::<A>(self.backend)
17 }
18
19 pub fn create_enum_from_entity<E>(&self, entity: E) -> Vec<TypeCreateStatement>
21 where
22 E: EntityTrait,
23 {
24 create_enum_from_entity(entity, self.backend)
25 }
26
27 pub fn create_table_from_entity<E>(&self, entity: E) -> TableCreateStatement
29 where
30 E: EntityTrait,
31 {
32 create_table_from_entity(entity, self.backend)
33 }
34
35 pub fn create_index_from_entity<E>(&self, entity: E) -> Vec<IndexCreateStatement>
38 where
39 E: EntityTrait,
40 {
41 create_index_from_entity(entity, self.backend)
42 }
43
44 pub fn get_column_def<E>(&self, column: E::Column) -> ColumnDef
85 where
86 E: EntityTrait,
87 {
88 column_def_from_entity_column::<E>(column, self.backend)
89 }
90}
91
92pub(crate) fn create_enum_from_active_enum<A>(backend: DbBackend) -> TypeCreateStatement
93where
94 A: ActiveEnum,
95{
96 if matches!(backend, DbBackend::MySql | DbBackend::Sqlite) {
97 panic!("TypeCreateStatement is not supported in MySQL & SQLite");
98 }
99 let col_def = A::db_type();
100 let col_type = col_def.get_column_type();
101 create_enum_from_column_type(col_type)
102}
103
104pub(crate) fn create_enum_from_column_type(col_type: &ColumnType) -> TypeCreateStatement {
105 let (name, values) = match col_type {
106 ColumnType::Enum { name, variants } => (name.clone(), variants.clone()),
107 _ => panic!("Should be ColumnType::Enum"),
108 };
109 Type::create().as_enum(name).values(values).to_owned()
110}
111
112#[allow(clippy::needless_borrow)]
113pub(crate) fn create_enum_from_entity<E>(_: E, backend: DbBackend) -> Vec<TypeCreateStatement>
114where
115 E: EntityTrait,
116{
117 if matches!(backend, DbBackend::MySql | DbBackend::Sqlite) {
118 return Vec::new();
119 }
120 let mut vec = Vec::new();
121 for col in E::Column::iter() {
122 let col_def = col.def();
123 let col_type = col_def.get_column_type();
124 if !matches!(col_type, ColumnType::Enum { .. }) {
125 continue;
126 }
127 let stmt = create_enum_from_column_type(&col_type);
128 vec.push(stmt);
129 }
130 vec
131}
132
133pub(crate) fn create_index_from_entity<E>(
134 entity: E,
135 _backend: DbBackend,
136) -> Vec<IndexCreateStatement>
137where
138 E: EntityTrait,
139{
140 let mut vec = Vec::new();
141 for column in E::Column::iter() {
142 let column_def = column.def();
143 if !column_def.indexed {
144 continue;
145 }
146 let stmt = Index::create()
147 .name(format!("idx-{}-{}", entity.to_string(), column.to_string()))
148 .table(entity)
149 .col(column)
150 .to_owned();
151 vec.push(stmt)
152 }
153 vec
154}
155
156pub(crate) fn create_table_from_entity<E>(entity: E, backend: DbBackend) -> TableCreateStatement
157where
158 E: EntityTrait,
159{
160 let mut stmt = TableCreateStatement::new();
161
162 if let Some(comment) = entity.comment() {
163 stmt.comment(comment);
164 }
165
166 for column in E::Column::iter() {
167 let mut column_def = column_def_from_entity_column::<E>(column, backend);
168 stmt.col(&mut column_def);
169 }
170
171 if <<E::PrimaryKey as PrimaryKeyTrait>::ValueType as PrimaryKeyArity>::ARITY > 1 {
172 let mut idx_pk = Index::create();
173 for primary_key in E::PrimaryKey::iter() {
174 idx_pk.col(primary_key);
175 }
176 stmt.primary_key(idx_pk.name(format!("pk-{}", entity.to_string())).primary());
177 }
178
179 for relation in E::Relation::iter() {
180 let relation = relation.def();
181 if relation.is_owner {
182 continue;
183 }
184 stmt.foreign_key(&mut relation.into());
185 }
186
187 stmt.table(entity.table_ref()).take()
188}
189
190fn column_def_from_entity_column<E>(column: E::Column, backend: DbBackend) -> ColumnDef
191where
192 E: EntityTrait,
193{
194 let orm_column_def = column.def();
195 let types = match &orm_column_def.col_type {
196 ColumnType::Enum { name, variants } => match backend {
197 DbBackend::MySql => {
198 let variants: Vec<String> = variants.iter().map(|v| v.to_string()).collect();
199 ColumnType::custom(format!("ENUM('{}')", variants.join("', '")).as_str())
200 }
201 DbBackend::Postgres => ColumnType::Custom(SeaRc::clone(name)),
202 DbBackend::Sqlite => orm_column_def.col_type,
203 },
204 _ => orm_column_def.col_type,
205 };
206 let mut column_def = ColumnDef::new_with_type(column, types);
207 if !orm_column_def.null {
208 column_def.not_null();
209 }
210 if orm_column_def.unique {
211 column_def.unique_key();
212 }
213 if let Some(default) = orm_column_def.default {
214 column_def.default(default);
215 }
216 if let Some(comment) = orm_column_def.comment {
217 column_def.comment(comment);
218 }
219 for primary_key in E::PrimaryKey::iter() {
220 if column.to_string() == primary_key.into_column().to_string() {
221 if E::PrimaryKey::auto_increment() {
222 column_def.auto_increment();
223 }
224 if <<E::PrimaryKey as PrimaryKeyTrait>::ValueType as PrimaryKeyArity>::ARITY == 1 {
225 column_def.primary_key();
226 }
227 }
228 }
229 column_def
230}
231
#[cfg(test)]
mod tests {
    use crate::{sea_query::*, tests_cfg::*, DbBackend, EntityName, Schema};
    use pretty_assertions::assert_eq;

    /// The table statement generated from `CakeFillingPrice` must render identically to the
    /// hand-written one on every backend.
    #[test]
    fn test_create_table_from_entity_table_ref() {
        for backend in [DbBackend::MySql, DbBackend::Postgres, DbBackend::Sqlite] {
            let schema = Schema::new(backend);

            let generated = schema.create_table_from_entity(CakeFillingPrice);
            let expected = get_cake_filling_price_stmt()
                .table(CakeFillingPrice.table_ref())
                .to_owned();

            assert_eq!(backend.build(&generated), backend.build(&expected));
        }
    }

    /// Hand-written `CREATE TABLE` for `cake_filling_price` (table name left unset).
    fn get_cake_filling_price_stmt() -> TableCreateStatement {
        let mut composite_key = Index::create();
        composite_key
            .name("pk-cake_filling_price")
            .col(cake_filling_price::Column::CakeId)
            .col(cake_filling_price::Column::FillingId)
            .primary();

        let mut cake_filling_fk = ForeignKeyCreateStatement::new();
        cake_filling_fk
            .name("fk-cake_filling_price-cake_id-filling_id")
            .from_tbl(CakeFillingPrice)
            .from_col(cake_filling_price::Column::CakeId)
            .from_col(cake_filling_price::Column::FillingId)
            .to_tbl(CakeFilling)
            .to_col(cake_filling::Column::CakeId)
            .to_col(cake_filling::Column::FillingId);

        let mut table = Table::create();
        table
            .col(
                ColumnDef::new(cake_filling_price::Column::CakeId)
                    .integer()
                    .not_null(),
            )
            .col(
                ColumnDef::new(cake_filling_price::Column::FillingId)
                    .integer()
                    .not_null(),
            )
            .col(
                ColumnDef::new(cake_filling_price::Column::Price)
                    .decimal()
                    .not_null(),
            )
            .primary_key(&mut composite_key)
            .foreign_key(&mut cake_filling_fk);
        table
    }

    /// Checks both the table statement and the two index statements generated for the
    /// `indexes` entity on every backend.
    #[test]
    fn test_create_index_from_entity_table_ref() {
        for backend in [DbBackend::MySql, DbBackend::Postgres, DbBackend::Sqlite] {
            let schema = Schema::new(backend);

            let generated_table = schema.create_table_from_entity(indexes::Entity);
            let expected_table = get_indexes_stmt()
                .table(indexes::Entity.table_ref())
                .to_owned();
            assert_eq!(
                backend.build(&generated_table),
                backend.build(&expected_table)
            );

            let stmts = schema.create_index_from_entity(indexes::Entity);
            assert_eq!(stmts.len(), 2);

            // Expected indexes, in the order the entity declares its columns.
            let expected_indexes = [
                ("idx-indexes-index1_attr", indexes::Column::Index1Attr),
                ("idx-indexes-index2_attr", indexes::Column::Index2Attr),
            ];
            for (stmt, (index_name, column)) in stmts.iter().zip(expected_indexes) {
                let idx: IndexCreateStatement = Index::create()
                    .name(index_name)
                    .table(indexes::Entity)
                    .col(column)
                    .to_owned();
                assert_eq!(backend.build(stmt), backend.build(&idx));
            }
        }
    }

    /// Hand-written `CREATE TABLE` for `indexes` (table name left unset).
    fn get_indexes_stmt() -> TableCreateStatement {
        let mut table = Table::create();
        table
            .col(
                ColumnDef::new(indexes::Column::IndexesId)
                    .integer()
                    .not_null()
                    .auto_increment()
                    .primary_key(),
            )
            .col(
                ColumnDef::new(indexes::Column::UniqueAttr)
                    .integer()
                    .not_null()
                    .unique_key(),
            )
            .col(
                ColumnDef::new(indexes::Column::Index1Attr)
                    .integer()
                    .not_null(),
            )
            .col(
                ColumnDef::new(indexes::Column::Index2Attr)
                    .integer()
                    .not_null()
                    .unique_key(),
            );
        table
    }
}