Skip to main content

sea_orm_codegen/entity/
transformer.rs

1use crate::{
2    ActiveEnum, Column, ConjunctRelation, Entity, EntityWriter, Error, PrimaryKey, Relation,
3    RelationType,
4};
5use sea_query::TableCreateStatement;
6use std::collections::{BTreeMap, HashMap, HashSet};
7
/// Stateless converter from `CREATE TABLE` statements to the entity model
/// that the code generator writes out; see `EntityTransformer::transform`.
#[derive(Debug, Clone)]
pub struct EntityTransformer;
10
11impl EntityTransformer {
12    pub fn transform(table_create_stmts: Vec<TableCreateStatement>) -> Result<EntityWriter, Error> {
13        let mut enums: BTreeMap<String, ActiveEnum> = BTreeMap::new();
14        let mut inverse_relations: BTreeMap<String, Vec<Relation>> = BTreeMap::new();
15        let mut entities = BTreeMap::new();
16        for table_create in table_create_stmts.into_iter() {
17            let table_name = match table_create.get_table_name() {
18                Some(table_ref) => table_ref.sea_orm_table().to_string(),
19                None => {
20                    return Err(Error::TransformError(
21                        "Table name should not be empty".into(),
22                    ));
23                }
24            };
25            let mut primary_keys: Vec<PrimaryKey> = Vec::new();
26            let mut columns: Vec<Column> = table_create
27                .get_columns()
28                .iter()
29                .map(|col_def| {
30                    let primary_key = col_def.get_column_spec().primary_key;
31                    if primary_key {
32                        primary_keys.push(PrimaryKey {
33                            name: col_def.get_column_name(),
34                        });
35                    }
36                    col_def.into()
37                })
38                .map(|mut col: Column| {
39                    col.unique = table_create
40                        .get_indexes()
41                        .iter()
42                        .filter(|index| index.is_unique_key())
43                        .map(|index| index.get_index_spec().get_column_names())
44                        .filter(|col_names| col_names.len() == 1 && col_names[0] == col.name)
45                        .count()
46                        > 0;
47                    col
48                })
49                .inspect(|col| {
50                    if let sea_query::ColumnType::Enum { name, variants } = col.get_inner_col_type()
51                    {
52                        enums.insert(
53                            name.to_string(),
54                            ActiveEnum {
55                                enum_name: name.clone(),
56                                values: variants.clone(),
57                            },
58                        );
59                    }
60                })
61                .collect();
62            for index in table_create.get_indexes().iter() {
63                if index.is_unique_key() {
64                    let col_names = index.get_index_spec().get_column_names();
65                    if col_names.len() > 1 {
66                        if let Some(mut key_name) = index.get_index_spec().get_name() {
67                            if let Some((_, suffix)) = key_name.rsplit_once('-') {
68                                key_name = suffix;
69                            }
70                            for col_name in col_names {
71                                for column in columns.iter_mut() {
72                                    if column.name == col_name {
73                                        column.unique_key = Some(key_name.to_owned());
74                                    }
75                                }
76                            }
77                        }
78                    }
79                }
80            }
81            let mut ref_table_counts: BTreeMap<String, usize> = BTreeMap::new();
82            let relations: Vec<Relation> = table_create
83                .get_foreign_key_create_stmts()
84                .iter()
85                .map(|fk_create_stmt| fk_create_stmt.get_foreign_key())
86                .map(|tbl_fk| {
87                    let ref_tbl = tbl_fk.get_ref_table().unwrap().sea_orm_table().to_string();
88                    if let Some(count) = ref_table_counts.get_mut(&ref_tbl) {
89                        if *count == 0 {
90                            *count = 1;
91                        }
92                        *count += 1;
93                    } else {
94                        ref_table_counts.insert(ref_tbl, 0);
95                    };
96                    tbl_fk.into()
97                })
98                .collect::<Vec<_>>()
99                .into_iter()
100                .rev()
101                .map(|mut rel: Relation| {
102                    rel.self_referencing = rel.ref_table == table_name;
103                    if let Some(count) = ref_table_counts.get_mut(&rel.ref_table) {
104                        rel.num_suffix = *count;
105                        if *count > 0 {
106                            *count -= 1;
107                        }
108                    }
109                    rel
110                })
111                .rev()
112                .collect();
113            primary_keys.extend(
114                table_create
115                    .get_indexes()
116                    .iter()
117                    .filter(|index| index.is_primary_key())
118                    .flat_map(|index| {
119                        index
120                            .get_index_spec()
121                            .get_column_names()
122                            .into_iter()
123                            .map(|name| PrimaryKey { name })
124                            .collect::<Vec<_>>()
125                    }),
126            );
127            let entity = Entity {
128                table_name: table_name.clone(),
129                columns,
130                relations: relations.clone(),
131                conjunct_relations: vec![],
132                primary_keys,
133            };
134            entities.insert(table_name.clone(), entity.clone());
135            for mut rel in relations.into_iter() {
136                // This will produce a duplicated relation
137                if rel.self_referencing {
138                    continue;
139                }
140                // This will cause compile error on the many side,
141                // got relation variant but without Related<T> implemented
142                if rel.num_suffix > 0 {
143                    continue;
144                }
145                let ref_table = rel.ref_table;
146                let mut unique = true;
147                for column in rel.columns.iter() {
148                    if !entity
149                        .columns
150                        .iter()
151                        .filter(|col| col.unique)
152                        .any(|col| col.name.as_str() == column)
153                    {
154                        unique = false;
155                        break;
156                    }
157                }
158                if rel.columns.len() == entity.primary_keys.len() {
159                    let mut count_pk = 0;
160                    for primary_key in entity.primary_keys.iter() {
161                        if rel.columns.contains(&primary_key.name) {
162                            count_pk += 1;
163                        }
164                    }
165                    if count_pk == entity.primary_keys.len() {
166                        unique = true;
167                    }
168                }
169                let rel_type = if unique {
170                    RelationType::HasOne
171                } else {
172                    RelationType::HasMany
173                };
174                rel.rel_type = rel_type;
175                rel.ref_table = table_name.to_string();
176                rel.columns = Vec::new();
177                rel.ref_columns = Vec::new();
178                if let Some(vec) = inverse_relations.get_mut(&ref_table) {
179                    vec.push(rel);
180                } else {
181                    inverse_relations.insert(ref_table, vec![rel]);
182                }
183            }
184        }
185        for (tbl_name, relations) in inverse_relations.into_iter() {
186            if let Some(entity) = entities.get_mut(&tbl_name) {
187                for relation in relations.into_iter() {
188                    let duplicate_relation = entity
189                        .relations
190                        .iter()
191                        .any(|rel| rel.ref_table == relation.ref_table);
192                    if !duplicate_relation {
193                        entity.relations.push(relation);
194                    }
195                }
196            }
197        }
198
199        // When codegen is fed with a subset of tables (e.g. via `sea-orm-cli generate entity --tables`),
200        // we must not generate relations that point to entities outside this set, otherwise it will
201        // produce invalid paths like `super::<missing_table>::Entity`.
202        let table_names: HashSet<String> = entities.keys().cloned().collect();
203        for entity in entities.values_mut() {
204            entity
205                .relations
206                .retain(|rel| rel.self_referencing || table_names.contains(&rel.ref_table));
207        }
208
209        for table_name in entities.clone().keys() {
210            let relations = match entities.get(table_name) {
211                Some(entity) => {
212                    let is_conjunct_relation =
213                        entity.relations.len() == 2 && entity.primary_keys.len() == 2;
214                    if !is_conjunct_relation {
215                        continue;
216                    }
217                    entity.relations.clone()
218                }
219                None => unreachable!(),
220            };
221            for (i, rel) in relations.iter().enumerate() {
222                let another_rel = relations.get((i == 0) as usize).unwrap();
223                if let Some(entity) = entities.get_mut(&rel.ref_table) {
224                    let conjunct_relation = ConjunctRelation {
225                        via: table_name.clone(),
226                        to: another_rel.ref_table.clone(),
227                    };
228                    entity.conjunct_relations.push(conjunct_relation);
229                }
230            }
231        }
232        Ok(EntityWriter {
233            entities: entities
234                .into_values()
235                .map(|mut v| {
236                    // Filter duplicated conjunct relations
237                    let duplicated_to: Vec<_> = v
238                        .conjunct_relations
239                        .iter()
240                        .fold(HashMap::new(), |mut acc, conjunct_relation| {
241                            acc.entry(conjunct_relation.to.clone())
242                                .and_modify(|c| *c += 1)
243                                .or_insert(1);
244                            acc
245                        })
246                        .into_iter()
247                        .filter(|(_, v)| v > &1)
248                        .map(|(k, _)| k)
249                        .collect();
250                    v.conjunct_relations
251                        .retain(|conjunct_relation| !duplicated_to.contains(&conjunct_relation.to));
252
253                    // Skip `impl Related ... { fn to() ... }` implementation block,
254                    // if the same related entity is being referenced by a conjunct relation
255                    v.relations.iter_mut().for_each(|relation| {
256                        if v.conjunct_relations
257                            .iter()
258                            .any(|conjunct_relation| conjunct_relation.to == relation.ref_table)
259                        {
260                            relation.impl_related = false;
261                        }
262                    });
263
264                    // Sort relation vectors
265                    v.relations.sort_by(|a, b| a.ref_table.cmp(&b.ref_table));
266                    v.conjunct_relations.sort_by(|a, b| a.to.cmp(&b.to));
267                    v
268                })
269                .collect(),
270            enums,
271        })
272    }
273}
274
275#[cfg(test)]
276mod tests {
277    use super::*;
278    use pretty_assertions::assert_eq;
279    use proc_macro2::TokenStream;
280    use sea_orm::{DbBackend, Schema};
281    use sea_query::{ColumnDef, ForeignKey, Table};
282    use std::{
283        error::Error,
284        io::{self, BufRead, BufReader},
285    };
286
287    #[test]
288    fn duplicated_many_to_many_paths() -> Result<(), Box<dyn Error>> {
289        use crate::tests_cfg::duplicated_many_to_many_paths::*;
290        let schema = Schema::new(DbBackend::Postgres);
291
292        validate_compact_entities(
293            vec![
294                schema.create_table_from_entity(bills::Entity),
295                schema.create_table_from_entity(users::Entity),
296                schema.create_table_from_entity(users_saved_bills::Entity),
297                schema.create_table_from_entity(users_votes::Entity),
298            ],
299            vec![
300                (
301                    "bills",
302                    include_str!("../tests_cfg/duplicated_many_to_many_paths/bills.rs"),
303                ),
304                (
305                    "users",
306                    include_str!("../tests_cfg/duplicated_many_to_many_paths/users.rs"),
307                ),
308                (
309                    "users_saved_bills",
310                    include_str!("../tests_cfg/duplicated_many_to_many_paths/users_saved_bills.rs"),
311                ),
312                (
313                    "users_votes",
314                    include_str!("../tests_cfg/duplicated_many_to_many_paths/users_votes.rs"),
315                ),
316            ],
317        )
318    }
319
320    #[test]
321    fn many_to_many() -> Result<(), Box<dyn Error>> {
322        use crate::tests_cfg::many_to_many::*;
323        let schema = Schema::new(DbBackend::Postgres);
324
325        validate_compact_entities(
326            vec![
327                schema.create_table_from_entity(bills::Entity),
328                schema.create_table_from_entity(users::Entity),
329                schema.create_table_from_entity(users_votes::Entity),
330            ],
331            vec![
332                ("bills", include_str!("../tests_cfg/many_to_many/bills.rs")),
333                ("users", include_str!("../tests_cfg/many_to_many/users.rs")),
334                (
335                    "users_votes",
336                    include_str!("../tests_cfg/many_to_many/users_votes.rs"),
337                ),
338            ],
339        )
340    }
341
342    #[test]
343    fn many_to_many_multiple() -> Result<(), Box<dyn Error>> {
344        use crate::tests_cfg::many_to_many_multiple::*;
345        let schema = Schema::new(DbBackend::Postgres);
346
347        validate_compact_entities(
348            vec![
349                schema.create_table_from_entity(bills::Entity),
350                schema.create_table_from_entity(users::Entity),
351                schema.create_table_from_entity(users_votes::Entity),
352            ],
353            vec![
354                (
355                    "bills",
356                    include_str!("../tests_cfg/many_to_many_multiple/bills.rs"),
357                ),
358                (
359                    "users",
360                    include_str!("../tests_cfg/many_to_many_multiple/users.rs"),
361                ),
362                (
363                    "users_votes",
364                    include_str!("../tests_cfg/many_to_many_multiple/users_votes.rs"),
365                ),
366            ],
367        )
368    }
369
370    #[test]
371    fn self_referencing() -> Result<(), Box<dyn Error>> {
372        use crate::tests_cfg::self_referencing::*;
373        let schema = Schema::new(DbBackend::Postgres);
374
375        validate_compact_entities(
376            vec![
377                schema.create_table_from_entity(bills::Entity),
378                schema.create_table_from_entity(users::Entity),
379            ],
380            vec![
381                (
382                    "bills",
383                    include_str!("../tests_cfg/self_referencing/bills.rs"),
384                ),
385                (
386                    "users",
387                    include_str!("../tests_cfg/self_referencing/users.rs"),
388                ),
389            ],
390        )
391    }
392
393    #[test]
394    fn test_indexes_transform() -> Result<(), Box<dyn Error>> {
395        let schema = Schema::new(DbBackend::Postgres);
396
397        validate_compact_entities(
398            vec![
399                schema.create_table_with_index_from_entity(
400                    crate::tests_cfg::compact::indexes::Entity,
401                ),
402            ],
403            vec![("indexes", include_str!("../tests_cfg/compact/indexes.rs"))],
404        )?;
405
406        validate_dense_entities(
407            vec![
408                schema
409                    .create_table_with_index_from_entity(crate::tests_cfg::dense::indexes::Entity),
410            ],
411            vec![("indexes", include_str!("../tests_cfg/dense/indexes.rs"))],
412        )?;
413
414        Ok(())
415    }
416
417    #[test]
418    fn filter_relations_to_missing_entities() -> Result<(), Box<dyn Error>> {
419        let parent_stmt = || {
420            Table::create()
421                .table("parent")
422                .col(
423                    ColumnDef::new("id")
424                        .integer()
425                        .not_null()
426                        .auto_increment()
427                        .primary_key(),
428                )
429                .to_owned()
430        };
431
432        let child_stmt = || {
433            Table::create()
434                .table("child")
435                .col(
436                    ColumnDef::new("id")
437                        .integer()
438                        .not_null()
439                        .auto_increment()
440                        .primary_key(),
441                )
442                .col(ColumnDef::new("parent_id").integer().not_null())
443                .foreign_key(
444                    ForeignKey::create()
445                        .name("fk-child-parent_id")
446                        .from("child", "parent_id")
447                        .to("parent", "id"),
448                )
449                .to_owned()
450        };
451
452        let entities: HashMap<_, _> =
453            EntityTransformer::transform(vec![parent_stmt(), child_stmt()])?
454                .entities
455                .into_iter()
456                .map(|entity| (entity.table_name.clone(), entity))
457                .collect();
458
459        let child = entities.get("child").expect("missing entity `child`");
460        assert_eq!(child.relations.len(), 1);
461        assert_eq!(child.relations[0].ref_table, "parent");
462
463        let entities: HashMap<_, _> = EntityTransformer::transform(vec![child_stmt()])?
464            .entities
465            .into_iter()
466            .map(|entity| (entity.table_name.clone(), entity))
467            .collect();
468
469        let child = entities.get("child").expect("missing entity `child`");
470        assert!(child.relations.is_empty());
471
472        Ok(())
473    }
474
475    #[test]
476    fn filter_conjunct_relations_to_missing_entities() -> Result<(), Box<dyn Error>> {
477        let user_stmt = || {
478            Table::create()
479                .table("user")
480                .col(
481                    ColumnDef::new("id")
482                        .integer()
483                        .not_null()
484                        .auto_increment()
485                        .primary_key(),
486                )
487                .to_owned()
488        };
489
490        let role_stmt = || {
491            Table::create()
492                .table("role")
493                .col(
494                    ColumnDef::new("id")
495                        .integer()
496                        .not_null()
497                        .auto_increment()
498                        .primary_key(),
499                )
500                .to_owned()
501        };
502
503        let user_role_stmt = || {
504            Table::create()
505                .table("user_role")
506                .col(ColumnDef::new("user_id").integer().not_null().primary_key())
507                .col(ColumnDef::new("role_id").integer().not_null().primary_key())
508                .foreign_key(
509                    ForeignKey::create()
510                        .name("fk-user_role-user_id")
511                        .from("user_role", "user_id")
512                        .to("user", "id"),
513                )
514                .foreign_key(
515                    ForeignKey::create()
516                        .name("fk-user_role-role_id")
517                        .from("user_role", "role_id")
518                        .to("role", "id"),
519                )
520                .to_owned()
521        };
522
523        let entities: HashMap<_, _> =
524            EntityTransformer::transform(vec![user_stmt(), role_stmt(), user_role_stmt()])?
525                .entities
526                .into_iter()
527                .map(|entity| (entity.table_name.clone(), entity))
528                .collect();
529
530        let user = entities.get("user").expect("missing entity `user`");
531        assert!(user.conjunct_relations.iter().any(|conjunct_relation| {
532            conjunct_relation.via == "user_role" && conjunct_relation.to == "role"
533        }));
534
535        let entities: HashMap<_, _> =
536            EntityTransformer::transform(vec![user_stmt(), user_role_stmt()])?
537                .entities
538                .into_iter()
539                .map(|entity| (entity.table_name.clone(), entity))
540                .collect();
541
542        let user = entities.get("user").expect("missing entity `user`");
543        assert!(user.conjunct_relations.is_empty());
544
545        let user_role = entities
546            .get("user_role")
547            .expect("missing entity `user_role`");
548        assert_eq!(user_role.relations.len(), 1);
549        assert_eq!(user_role.relations[0].ref_table, "user");
550
551        Ok(())
552    }
553
554    macro_rules! validate_entities_fn {
555        ($fn_name: ident, $method: ident) => {
556            fn $fn_name(
557                table_create_stmts: Vec<TableCreateStatement>,
558                files: Vec<(&str, &str)>,
559            ) -> Result<(), Box<dyn Error>> {
560                let entities: HashMap<_, _> = EntityTransformer::transform(table_create_stmts)?
561                    .entities
562                    .into_iter()
563                    .map(|entity| (entity.table_name.clone(), entity))
564                    .collect();
565
566                for (entity_name, file_content) in files {
567                    let entity = entities
568                        .get(entity_name)
569                        .expect("Forget to add entity to the list");
570
571                    assert_eq!(
572                        parse_from_file(file_content.as_bytes())?.to_string(),
573                        EntityWriter::$method(
574                            entity,
575                            &crate::WithSerde::None,
576                            &Default::default(),
577                            &None,
578                            false,
579                            false,
580                            &Default::default(),
581                            &Default::default(),
582                            &Default::default(),
583                            false,
584                            true,
585                        )
586                        .into_iter()
587                        .skip(1)
588                        .fold(TokenStream::new(), |mut acc, tok| {
589                            acc.extend(tok);
590                            acc
591                        })
592                        .to_string()
593                    );
594                }
595
596                Ok(())
597            }
598        };
599    }
600
601    validate_entities_fn!(validate_compact_entities, gen_compact_code_blocks);
602    validate_entities_fn!(validate_dense_entities, gen_dense_code_blocks);
603
604    fn parse_from_file<R>(inner: R) -> io::Result<TokenStream>
605    where
606        R: io::Read,
607    {
608        let mut reader = BufReader::new(inner);
609        let mut lines: Vec<String> = Vec::new();
610
611        reader.read_until(b';', &mut Vec::new())?;
612
613        let mut line = String::new();
614        while reader.read_line(&mut line)? > 0 {
615            lines.push(line.to_owned());
616            line.clear();
617        }
618        let content = lines.join("");
619        Ok(content.parse().unwrap())
620    }
621}