1use crate::{
2 ActiveEnum, Column, ConjunctRelation, Entity, EntityWriter, Error, PrimaryKey, Relation,
3 RelationType,
4};
5use sea_query::TableCreateStatement;
6use std::collections::{BTreeMap, HashMap, HashSet};
7
/// Unit type whose only job is to host [`EntityTransformer::transform`], which turns
/// `CREATE TABLE` statements into entity definitions. It carries no state.
#[derive(Debug, Clone)]
pub struct EntityTransformer;
10
11impl EntityTransformer {
12 pub fn transform(table_create_stmts: Vec<TableCreateStatement>) -> Result<EntityWriter, Error> {
13 let mut enums: BTreeMap<String, ActiveEnum> = BTreeMap::new();
14 let mut inverse_relations: BTreeMap<String, Vec<Relation>> = BTreeMap::new();
15 let mut entities = BTreeMap::new();
16 for table_create in table_create_stmts.into_iter() {
17 let table_name = match table_create.get_table_name() {
18 Some(table_ref) => table_ref.sea_orm_table().to_string(),
19 None => {
20 return Err(Error::TransformError(
21 "Table name should not be empty".into(),
22 ));
23 }
24 };
25 let mut primary_keys: Vec<PrimaryKey> = Vec::new();
26 let mut columns: Vec<Column> = table_create
27 .get_columns()
28 .iter()
29 .map(|col_def| {
30 let primary_key = col_def.get_column_spec().primary_key;
31 if primary_key {
32 primary_keys.push(PrimaryKey {
33 name: col_def.get_column_name(),
34 });
35 }
36 col_def.into()
37 })
38 .map(|mut col: Column| {
39 col.unique = table_create
40 .get_indexes()
41 .iter()
42 .filter(|index| index.is_unique_key())
43 .map(|index| index.get_index_spec().get_column_names())
44 .filter(|col_names| col_names.len() == 1 && col_names[0] == col.name)
45 .count()
46 > 0;
47 col
48 })
49 .inspect(|col| {
50 if let sea_query::ColumnType::Enum { name, variants } = col.get_inner_col_type()
51 {
52 enums.insert(
53 name.to_string(),
54 ActiveEnum {
55 enum_name: name.clone(),
56 values: variants.clone(),
57 },
58 );
59 }
60 })
61 .collect();
62 for index in table_create.get_indexes().iter() {
63 if index.is_unique_key() {
64 let col_names = index.get_index_spec().get_column_names();
65 if col_names.len() > 1 {
66 if let Some(mut key_name) = index.get_index_spec().get_name() {
67 if let Some((_, suffix)) = key_name.rsplit_once('-') {
68 key_name = suffix;
69 }
70 for col_name in col_names {
71 for column in columns.iter_mut() {
72 if column.name == col_name {
73 column.unique_key = Some(key_name.to_owned());
74 }
75 }
76 }
77 }
78 }
79 }
80 }
81 let mut ref_table_counts: BTreeMap<String, usize> = BTreeMap::new();
82 let relations: Vec<Relation> = table_create
83 .get_foreign_key_create_stmts()
84 .iter()
85 .map(|fk_create_stmt| fk_create_stmt.get_foreign_key())
86 .map(|tbl_fk| {
87 let ref_tbl = tbl_fk.get_ref_table().unwrap().sea_orm_table().to_string();
88 if let Some(count) = ref_table_counts.get_mut(&ref_tbl) {
89 if *count == 0 {
90 *count = 1;
91 }
92 *count += 1;
93 } else {
94 ref_table_counts.insert(ref_tbl, 0);
95 };
96 tbl_fk.into()
97 })
98 .collect::<Vec<_>>()
99 .into_iter()
100 .rev()
101 .map(|mut rel: Relation| {
102 rel.self_referencing = rel.ref_table == table_name;
103 if let Some(count) = ref_table_counts.get_mut(&rel.ref_table) {
104 rel.num_suffix = *count;
105 if *count > 0 {
106 *count -= 1;
107 }
108 }
109 rel
110 })
111 .rev()
112 .collect();
113 primary_keys.extend(
114 table_create
115 .get_indexes()
116 .iter()
117 .filter(|index| index.is_primary_key())
118 .flat_map(|index| {
119 index
120 .get_index_spec()
121 .get_column_names()
122 .into_iter()
123 .map(|name| PrimaryKey { name })
124 .collect::<Vec<_>>()
125 }),
126 );
127 let entity = Entity {
128 table_name: table_name.clone(),
129 columns,
130 relations: relations.clone(),
131 conjunct_relations: vec![],
132 primary_keys,
133 };
134 entities.insert(table_name.clone(), entity.clone());
135 for mut rel in relations.into_iter() {
136 if rel.self_referencing {
138 continue;
139 }
140 if rel.num_suffix > 0 {
143 continue;
144 }
145 let ref_table = rel.ref_table;
146 let mut unique = true;
147 for column in rel.columns.iter() {
148 if !entity
149 .columns
150 .iter()
151 .filter(|col| col.unique)
152 .any(|col| col.name.as_str() == column)
153 {
154 unique = false;
155 break;
156 }
157 }
158 if rel.columns.len() == entity.primary_keys.len() {
159 let mut count_pk = 0;
160 for primary_key in entity.primary_keys.iter() {
161 if rel.columns.contains(&primary_key.name) {
162 count_pk += 1;
163 }
164 }
165 if count_pk == entity.primary_keys.len() {
166 unique = true;
167 }
168 }
169 let rel_type = if unique {
170 RelationType::HasOne
171 } else {
172 RelationType::HasMany
173 };
174 rel.rel_type = rel_type;
175 rel.ref_table = table_name.to_string();
176 rel.columns = Vec::new();
177 rel.ref_columns = Vec::new();
178 if let Some(vec) = inverse_relations.get_mut(&ref_table) {
179 vec.push(rel);
180 } else {
181 inverse_relations.insert(ref_table, vec![rel]);
182 }
183 }
184 }
185 for (tbl_name, relations) in inverse_relations.into_iter() {
186 if let Some(entity) = entities.get_mut(&tbl_name) {
187 for relation in relations.into_iter() {
188 let duplicate_relation = entity
189 .relations
190 .iter()
191 .any(|rel| rel.ref_table == relation.ref_table);
192 if !duplicate_relation {
193 entity.relations.push(relation);
194 }
195 }
196 }
197 }
198
199 let table_names: HashSet<String> = entities.keys().cloned().collect();
203 for entity in entities.values_mut() {
204 entity
205 .relations
206 .retain(|rel| rel.self_referencing || table_names.contains(&rel.ref_table));
207 }
208
209 for table_name in entities.clone().keys() {
210 let relations = match entities.get(table_name) {
211 Some(entity) => {
212 let is_conjunct_relation =
213 entity.relations.len() == 2 && entity.primary_keys.len() == 2;
214 if !is_conjunct_relation {
215 continue;
216 }
217 entity.relations.clone()
218 }
219 None => unreachable!(),
220 };
221 for (i, rel) in relations.iter().enumerate() {
222 let another_rel = relations.get((i == 0) as usize).unwrap();
223 if let Some(entity) = entities.get_mut(&rel.ref_table) {
224 let conjunct_relation = ConjunctRelation {
225 via: table_name.clone(),
226 to: another_rel.ref_table.clone(),
227 };
228 entity.conjunct_relations.push(conjunct_relation);
229 }
230 }
231 }
232 Ok(EntityWriter {
233 entities: entities
234 .into_values()
235 .map(|mut v| {
236 let duplicated_to: Vec<_> = v
238 .conjunct_relations
239 .iter()
240 .fold(HashMap::new(), |mut acc, conjunct_relation| {
241 acc.entry(conjunct_relation.to.clone())
242 .and_modify(|c| *c += 1)
243 .or_insert(1);
244 acc
245 })
246 .into_iter()
247 .filter(|(_, v)| v > &1)
248 .map(|(k, _)| k)
249 .collect();
250 v.conjunct_relations
251 .retain(|conjunct_relation| !duplicated_to.contains(&conjunct_relation.to));
252
253 v.relations.iter_mut().for_each(|relation| {
256 if v.conjunct_relations
257 .iter()
258 .any(|conjunct_relation| conjunct_relation.to == relation.ref_table)
259 {
260 relation.impl_related = false;
261 }
262 });
263
264 v.relations.sort_by(|a, b| a.ref_table.cmp(&b.ref_table));
266 v.conjunct_relations.sort_by(|a, b| a.to.cmp(&b.to));
267 v
268 })
269 .collect(),
270 enums,
271 })
272 }
273}
274
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use proc_macro2::TokenStream;
    use sea_orm::{DbBackend, Schema};
    use sea_query::{ColumnDef, ForeignKey, Table};
    use std::{
        error::Error,
        io::{self, BufRead, BufReader, Read},
    };

    /// Runs [`EntityTransformer::transform`] and indexes the resulting entities by table name.
    fn transform_to_map(
        table_create_stmts: Vec<TableCreateStatement>,
    ) -> Result<HashMap<String, Entity>, Box<dyn Error>> {
        Ok(EntityTransformer::transform(table_create_stmts)?
            .entities
            .into_iter()
            .map(|entity| (entity.table_name.clone(), entity))
            .collect())
    }

    #[test]
    fn duplicated_many_to_many_paths() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::duplicated_many_to_many_paths::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
                schema.create_table_from_entity(users_saved_bills::Entity),
                schema.create_table_from_entity(users_votes::Entity),
            ],
            vec![
                (
                    "bills",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/bills.rs"),
                ),
                (
                    "users",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/users.rs"),
                ),
                (
                    "users_saved_bills",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/users_saved_bills.rs"),
                ),
                (
                    "users_votes",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/users_votes.rs"),
                ),
            ],
        )
    }

    #[test]
    fn many_to_many() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::many_to_many::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
                schema.create_table_from_entity(users_votes::Entity),
            ],
            vec![
                ("bills", include_str!("../tests_cfg/many_to_many/bills.rs")),
                ("users", include_str!("../tests_cfg/many_to_many/users.rs")),
                (
                    "users_votes",
                    include_str!("../tests_cfg/many_to_many/users_votes.rs"),
                ),
            ],
        )
    }

    #[test]
    fn many_to_many_multiple() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::many_to_many_multiple::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
                schema.create_table_from_entity(users_votes::Entity),
            ],
            vec![
                (
                    "bills",
                    include_str!("../tests_cfg/many_to_many_multiple/bills.rs"),
                ),
                (
                    "users",
                    include_str!("../tests_cfg/many_to_many_multiple/users.rs"),
                ),
                (
                    "users_votes",
                    include_str!("../tests_cfg/many_to_many_multiple/users_votes.rs"),
                ),
            ],
        )
    }

    #[test]
    fn self_referencing() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::self_referencing::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
            ],
            vec![
                (
                    "bills",
                    include_str!("../tests_cfg/self_referencing/bills.rs"),
                ),
                (
                    "users",
                    include_str!("../tests_cfg/self_referencing/users.rs"),
                ),
            ],
        )
    }

    #[test]
    fn test_indexes_transform() -> Result<(), Box<dyn Error>> {
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_with_index_from_entity(
                    crate::tests_cfg::compact::indexes::Entity,
                ),
            ],
            vec![("indexes", include_str!("../tests_cfg/compact/indexes.rs"))],
        )?;

        validate_dense_entities(
            vec![
                schema
                    .create_table_with_index_from_entity(crate::tests_cfg::dense::indexes::Entity),
            ],
            vec![("indexes", include_str!("../tests_cfg/dense/indexes.rs"))],
        )?;

        Ok(())
    }

    #[test]
    fn filter_relations_to_missing_entities() -> Result<(), Box<dyn Error>> {
        let parent_stmt = || {
            Table::create()
                .table("parent")
                .col(
                    ColumnDef::new("id")
                        .integer()
                        .not_null()
                        .auto_increment()
                        .primary_key(),
                )
                .to_owned()
        };

        let child_stmt = || {
            Table::create()
                .table("child")
                .col(
                    ColumnDef::new("id")
                        .integer()
                        .not_null()
                        .auto_increment()
                        .primary_key(),
                )
                .col(ColumnDef::new("parent_id").integer().not_null())
                .foreign_key(
                    ForeignKey::create()
                        .name("fk-child-parent_id")
                        .from("child", "parent_id")
                        .to("parent", "id"),
                )
                .to_owned()
        };

        // With the parent present, the relation survives.
        let entities = transform_to_map(vec![parent_stmt(), child_stmt()])?;
        let child = entities.get("child").expect("missing entity `child`");
        assert_eq!(child.relations.len(), 1);
        assert_eq!(child.relations[0].ref_table, "parent");

        // Without the parent, the dangling relation is filtered out.
        let entities = transform_to_map(vec![child_stmt()])?;
        let child = entities.get("child").expect("missing entity `child`");
        assert!(child.relations.is_empty());

        Ok(())
    }

    #[test]
    fn filter_conjunct_relations_to_missing_entities() -> Result<(), Box<dyn Error>> {
        let user_stmt = || {
            Table::create()
                .table("user")
                .col(
                    ColumnDef::new("id")
                        .integer()
                        .not_null()
                        .auto_increment()
                        .primary_key(),
                )
                .to_owned()
        };

        let role_stmt = || {
            Table::create()
                .table("role")
                .col(
                    ColumnDef::new("id")
                        .integer()
                        .not_null()
                        .auto_increment()
                        .primary_key(),
                )
                .to_owned()
        };

        let user_role_stmt = || {
            Table::create()
                .table("user_role")
                .col(ColumnDef::new("user_id").integer().not_null().primary_key())
                .col(ColumnDef::new("role_id").integer().not_null().primary_key())
                .foreign_key(
                    ForeignKey::create()
                        .name("fk-user_role-user_id")
                        .from("user_role", "user_id")
                        .to("user", "id"),
                )
                .foreign_key(
                    ForeignKey::create()
                        .name("fk-user_role-role_id")
                        .from("user_role", "role_id")
                        .to("role", "id"),
                )
                .to_owned()
        };

        // Both sides present: `user` gets a conjunct relation to `role` via `user_role`.
        let entities = transform_to_map(vec![user_stmt(), role_stmt(), user_role_stmt()])?;
        let user = entities.get("user").expect("missing entity `user`");
        assert!(user.conjunct_relations.iter().any(|conjunct_relation| {
            conjunct_relation.via == "user_role" && conjunct_relation.to == "role"
        }));

        // `role` missing: no conjunct relation, and the junction keeps only the `user` side.
        let entities = transform_to_map(vec![user_stmt(), user_role_stmt()])?;
        let user = entities.get("user").expect("missing entity `user`");
        assert!(user.conjunct_relations.is_empty());

        let user_role = entities
            .get("user_role")
            .expect("missing entity `user_role`");
        assert_eq!(user_role.relations.len(), 1);
        assert_eq!(user_role.relations[0].ref_table, "user");

        Ok(())
    }

    /// Generates a validator that transforms `table_create_stmts` and compares the code
    /// produced by `EntityWriter::$method` for each named entity against the expected
    /// source file contents.
    macro_rules! validate_entities_fn {
        ($fn_name: ident, $method: ident) => {
            fn $fn_name(
                table_create_stmts: Vec<TableCreateStatement>,
                files: Vec<(&str, &str)>,
            ) -> Result<(), Box<dyn Error>> {
                let entities = transform_to_map(table_create_stmts)?;

                for (entity_name, file_content) in files {
                    let entity = entities.get(entity_name).unwrap_or_else(|| {
                        panic!(
                            "missing entity `{}`; forgot to add it to the list?",
                            entity_name
                        )
                    });

                    let expected = parse_from_file(file_content.as_bytes())?.to_string();
                    let generated = EntityWriter::$method(
                        entity,
                        &crate::WithSerde::None,
                        &Default::default(),
                        &None,
                        false,
                        false,
                        &Default::default(),
                        &Default::default(),
                        &Default::default(),
                        false,
                        true,
                    )
                    .into_iter()
                    // NOTE(review): the first generated block is skipped. Presumably it is the
                    // prelude import that `parse_from_file` also strips; confirm.
                    .skip(1)
                    .fold(TokenStream::new(), |mut acc, tok| {
                        acc.extend(tok);
                        acc
                    })
                    .to_string();

                    assert_eq!(expected, generated);
                }

                Ok(())
            }
        };
    }

    validate_entities_fn!(validate_compact_entities, gen_compact_code_blocks);
    validate_entities_fn!(validate_dense_entities, gen_dense_code_blocks);

    /// Reads an expected entity file and tokenizes everything after its first `;`.
    ///
    /// Everything up to and including the first `;` is discarded (the file header plus the
    /// first statement).
    ///
    /// # Errors
    ///
    /// Returns I/O errors from `inner` (including invalid UTF-8). Returns an
    /// `InvalidData` error if the remaining text is not a valid token stream.
    fn parse_from_file<R>(inner: R) -> io::Result<TokenStream>
    where
        R: io::Read,
    {
        let mut reader = BufReader::new(inner);
        reader.read_until(b';', &mut Vec::new())?;

        let mut content = String::new();
        reader.read_to_string(&mut content)?;
        content
            .parse()
            .map_err(|err: proc_macro2::LexError| {
                io::Error::new(io::ErrorKind::InvalidData, err.to_string())
            })
    }
}