sea_orm_codegen/entity/
transformer.rs

1use crate::{
2    ActiveEnum, Column, ConjunctRelation, Entity, EntityWriter, Error, PrimaryKey, Relation,
3    RelationType,
4};
5use sea_query::TableCreateStatement;
6use std::collections::{BTreeMap, HashMap};
7
/// Stateless converter that turns `sea_query` table-creation statements into
/// entity descriptions ready for code generation; the entry point is
/// `EntityTransformer::transform`.
#[derive(Debug, Clone)]
pub struct EntityTransformer;
10
11impl EntityTransformer {
12    pub fn transform(table_create_stmts: Vec<TableCreateStatement>) -> Result<EntityWriter, Error> {
13        let mut enums: BTreeMap<String, ActiveEnum> = BTreeMap::new();
14        let mut inverse_relations: BTreeMap<String, Vec<Relation>> = BTreeMap::new();
15        let mut entities = BTreeMap::new();
16        for table_create in table_create_stmts.into_iter() {
17            let table_name = match table_create.get_table_name() {
18                Some(table_ref) => table_ref.sea_orm_table().to_string(),
19                None => {
20                    return Err(Error::TransformError(
21                        "Table name should not be empty".into(),
22                    ));
23                }
24            };
25            let mut primary_keys: Vec<PrimaryKey> = Vec::new();
26            let mut columns: Vec<Column> = table_create
27                .get_columns()
28                .iter()
29                .map(|col_def| {
30                    let primary_key = col_def.get_column_spec().primary_key;
31                    if primary_key {
32                        primary_keys.push(PrimaryKey {
33                            name: col_def.get_column_name(),
34                        });
35                    }
36                    col_def.into()
37                })
38                .map(|mut col: Column| {
39                    col.unique = table_create
40                        .get_indexes()
41                        .iter()
42                        .filter(|index| index.is_unique_key())
43                        .map(|index| index.get_index_spec().get_column_names())
44                        .filter(|col_names| col_names.len() == 1 && col_names[0] == col.name)
45                        .count()
46                        > 0;
47                    col
48                })
49                .inspect(|col| {
50                    if let sea_query::ColumnType::Enum { name, variants } = col.get_inner_col_type()
51                    {
52                        enums.insert(
53                            name.to_string(),
54                            ActiveEnum {
55                                enum_name: name.clone(),
56                                values: variants.clone(),
57                            },
58                        );
59                    }
60                })
61                .collect();
62            for index in table_create.get_indexes().iter() {
63                if index.is_unique_key() {
64                    let col_names = index.get_index_spec().get_column_names();
65                    if col_names.len() > 1 {
66                        if let Some(mut key_name) = index.get_index_spec().get_name() {
67                            if let Some((_, suffix)) = key_name.rsplit_once('-') {
68                                key_name = suffix;
69                            }
70                            for col_name in col_names {
71                                for column in columns.iter_mut() {
72                                    if column.name == col_name {
73                                        column.unique_key = Some(key_name.to_owned());
74                                    }
75                                }
76                            }
77                        }
78                    }
79                }
80            }
81            let mut ref_table_counts: BTreeMap<String, usize> = BTreeMap::new();
82            let relations: Vec<Relation> = table_create
83                .get_foreign_key_create_stmts()
84                .iter()
85                .map(|fk_create_stmt| fk_create_stmt.get_foreign_key())
86                .map(|tbl_fk| {
87                    let ref_tbl = tbl_fk.get_ref_table().unwrap().sea_orm_table().to_string();
88                    if let Some(count) = ref_table_counts.get_mut(&ref_tbl) {
89                        if *count == 0 {
90                            *count = 1;
91                        }
92                        *count += 1;
93                    } else {
94                        ref_table_counts.insert(ref_tbl, 0);
95                    };
96                    tbl_fk.into()
97                })
98                .collect::<Vec<_>>()
99                .into_iter()
100                .rev()
101                .map(|mut rel: Relation| {
102                    rel.self_referencing = rel.ref_table == table_name;
103                    if let Some(count) = ref_table_counts.get_mut(&rel.ref_table) {
104                        rel.num_suffix = *count;
105                        if *count > 0 {
106                            *count -= 1;
107                        }
108                    }
109                    rel
110                })
111                .rev()
112                .collect();
113            primary_keys.extend(
114                table_create
115                    .get_indexes()
116                    .iter()
117                    .filter(|index| index.is_primary_key())
118                    .flat_map(|index| {
119                        index
120                            .get_index_spec()
121                            .get_column_names()
122                            .into_iter()
123                            .map(|name| PrimaryKey { name })
124                            .collect::<Vec<_>>()
125                    }),
126            );
127            let entity = Entity {
128                table_name: table_name.clone(),
129                columns,
130                relations: relations.clone(),
131                conjunct_relations: vec![],
132                primary_keys,
133            };
134            entities.insert(table_name.clone(), entity.clone());
135            for mut rel in relations.into_iter() {
136                // This will produce a duplicated relation
137                if rel.self_referencing {
138                    continue;
139                }
140                // This will cause compile error on the many side,
141                // got relation variant but without Related<T> implemented
142                if rel.num_suffix > 0 {
143                    continue;
144                }
145                let ref_table = rel.ref_table;
146                let mut unique = true;
147                for column in rel.columns.iter() {
148                    if !entity
149                        .columns
150                        .iter()
151                        .filter(|col| col.unique)
152                        .any(|col| col.name.as_str() == column)
153                    {
154                        unique = false;
155                        break;
156                    }
157                }
158                if rel.columns.len() == entity.primary_keys.len() {
159                    let mut count_pk = 0;
160                    for primary_key in entity.primary_keys.iter() {
161                        if rel.columns.contains(&primary_key.name) {
162                            count_pk += 1;
163                        }
164                    }
165                    if count_pk == entity.primary_keys.len() {
166                        unique = true;
167                    }
168                }
169                let rel_type = if unique {
170                    RelationType::HasOne
171                } else {
172                    RelationType::HasMany
173                };
174                rel.rel_type = rel_type;
175                rel.ref_table = table_name.to_string();
176                rel.columns = Vec::new();
177                rel.ref_columns = Vec::new();
178                if let Some(vec) = inverse_relations.get_mut(&ref_table) {
179                    vec.push(rel);
180                } else {
181                    inverse_relations.insert(ref_table, vec![rel]);
182                }
183            }
184        }
185        for (tbl_name, relations) in inverse_relations.into_iter() {
186            if let Some(entity) = entities.get_mut(&tbl_name) {
187                for relation in relations.into_iter() {
188                    let duplicate_relation = entity
189                        .relations
190                        .iter()
191                        .any(|rel| rel.ref_table == relation.ref_table);
192                    if !duplicate_relation {
193                        entity.relations.push(relation);
194                    }
195                }
196            }
197        }
198        for table_name in entities.clone().keys() {
199            let relations = match entities.get(table_name) {
200                Some(entity) => {
201                    let is_conjunct_relation =
202                        entity.relations.len() == 2 && entity.primary_keys.len() == 2;
203                    if !is_conjunct_relation {
204                        continue;
205                    }
206                    entity.relations.clone()
207                }
208                None => unreachable!(),
209            };
210            for (i, rel) in relations.iter().enumerate() {
211                let another_rel = relations.get((i == 0) as usize).unwrap();
212                if let Some(entity) = entities.get_mut(&rel.ref_table) {
213                    let conjunct_relation = ConjunctRelation {
214                        via: table_name.clone(),
215                        to: another_rel.ref_table.clone(),
216                    };
217                    entity.conjunct_relations.push(conjunct_relation);
218                }
219            }
220        }
221        Ok(EntityWriter {
222            entities: entities
223                .into_values()
224                .map(|mut v| {
225                    // Filter duplicated conjunct relations
226                    let duplicated_to: Vec<_> = v
227                        .conjunct_relations
228                        .iter()
229                        .fold(HashMap::new(), |mut acc, conjunct_relation| {
230                            acc.entry(conjunct_relation.to.clone())
231                                .and_modify(|c| *c += 1)
232                                .or_insert(1);
233                            acc
234                        })
235                        .into_iter()
236                        .filter(|(_, v)| v > &1)
237                        .map(|(k, _)| k)
238                        .collect();
239                    v.conjunct_relations
240                        .retain(|conjunct_relation| !duplicated_to.contains(&conjunct_relation.to));
241
242                    // Skip `impl Related ... { fn to() ... }` implementation block,
243                    // if the same related entity is being referenced by a conjunct relation
244                    v.relations.iter_mut().for_each(|relation| {
245                        if v.conjunct_relations
246                            .iter()
247                            .any(|conjunct_relation| conjunct_relation.to == relation.ref_table)
248                        {
249                            relation.impl_related = false;
250                        }
251                    });
252
253                    // Sort relation vectors
254                    v.relations.sort_by(|a, b| a.ref_table.cmp(&b.ref_table));
255                    v.conjunct_relations.sort_by(|a, b| a.to.cmp(&b.to));
256                    v
257                })
258                .collect(),
259            enums,
260        })
261    }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use proc_macro2::TokenStream;
    use sea_orm::{DbBackend, Schema};
    use std::{
        error::Error,
        io::{self, BufRead, BufReader, Read},
    };

    #[test]
    fn duplicated_many_to_many_paths() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::duplicated_many_to_many_paths::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
                schema.create_table_from_entity(users_saved_bills::Entity),
                schema.create_table_from_entity(users_votes::Entity),
            ],
            vec![
                (
                    "bills",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/bills.rs"),
                ),
                (
                    "users",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/users.rs"),
                ),
                (
                    "users_saved_bills",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/users_saved_bills.rs"),
                ),
                (
                    "users_votes",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/users_votes.rs"),
                ),
            ],
        )
    }

    #[test]
    fn many_to_many() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::many_to_many::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
                schema.create_table_from_entity(users_votes::Entity),
            ],
            vec![
                ("bills", include_str!("../tests_cfg/many_to_many/bills.rs")),
                ("users", include_str!("../tests_cfg/many_to_many/users.rs")),
                (
                    "users_votes",
                    include_str!("../tests_cfg/many_to_many/users_votes.rs"),
                ),
            ],
        )
    }

    #[test]
    fn many_to_many_multiple() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::many_to_many_multiple::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
                schema.create_table_from_entity(users_votes::Entity),
            ],
            vec![
                (
                    "bills",
                    include_str!("../tests_cfg/many_to_many_multiple/bills.rs"),
                ),
                (
                    "users",
                    include_str!("../tests_cfg/many_to_many_multiple/users.rs"),
                ),
                (
                    "users_votes",
                    include_str!("../tests_cfg/many_to_many_multiple/users_votes.rs"),
                ),
            ],
        )
    }

    #[test]
    fn self_referencing() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::self_referencing::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
            ],
            vec![
                (
                    "bills",
                    include_str!("../tests_cfg/self_referencing/bills.rs"),
                ),
                (
                    "users",
                    include_str!("../tests_cfg/self_referencing/users.rs"),
                ),
            ],
        )
    }

    #[test]
    fn test_indexes_transform() -> Result<(), Box<dyn Error>> {
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_with_index_from_entity(
                    crate::tests_cfg::compact::indexes::Entity,
                ),
            ],
            vec![("indexes", include_str!("../tests_cfg/compact/indexes.rs"))],
        )?;

        validate_dense_entities(
            vec![
                schema
                    .create_table_with_index_from_entity(crate::tests_cfg::dense::indexes::Entity),
            ],
            vec![("indexes", include_str!("../tests_cfg/dense/indexes.rs"))],
        )?;

        Ok(())
    }

    /// Defines `$fn_name`, which runs `EntityTransformer::transform` on the given
    /// statements and compares `EntityWriter::$method` output (minus its first code
    /// block) against each `(table_name, fixture_source)` pair.
    macro_rules! validate_entities_fn {
        ($fn_name: ident, $method: ident) => {
            fn $fn_name(
                table_create_stmts: Vec<TableCreateStatement>,
                files: Vec<(&str, &str)>,
            ) -> Result<(), Box<dyn Error>> {
                let entities: HashMap<_, _> = EntityTransformer::transform(table_create_stmts)?
                    .entities
                    .into_iter()
                    .map(|entity| (entity.table_name.clone(), entity))
                    .collect();

                for (entity_name, file_content) in files {
                    let entity = entities
                        .get(entity_name)
                        .expect("Forget to add entity to the list");

                    assert_eq!(
                        parse_from_file(file_content.as_bytes())?.to_string(),
                        EntityWriter::$method(
                            entity,
                            &crate::WithSerde::None,
                            &crate::DateTimeCrate::Chrono,
                            &None,
                            false,
                            false,
                            &Default::default(),
                            &Default::default(),
                            &Default::default(),
                            false,
                            true,
                        )
                        .into_iter()
                        .skip(1)
                        .fold(TokenStream::new(), |mut acc, tok| {
                            acc.extend(tok);
                            acc
                        })
                        .to_string()
                    );
                }

                Ok(())
            }
        };
    }

    validate_entities_fn!(validate_compact_entities, gen_compact_code_blocks);
    validate_entities_fn!(validate_dense_entities, gen_dense_code_blocks);

    /// Parses everything after the first `;` of `inner` into a [`TokenStream`].
    ///
    /// The skipped prefix presumably corresponds to the fixture's leading `use ...;`
    /// item, mirroring the `.skip(1)` applied to the generated blocks — confirm
    /// against the fixture files.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if reading fails (including invalid UTF-8). It also
    /// returns `io::ErrorKind::InvalidData` if the remaining text is not a valid
    /// token stream.
    fn parse_from_file<R>(inner: R) -> io::Result<TokenStream>
    where
        R: io::Read,
    {
        let mut reader = BufReader::new(inner);

        // Discard everything up to and including the first `;`.
        reader.read_until(b';', &mut Vec::new())?;

        let mut content = String::new();
        reader.read_to_string(&mut content)?;
        content
            .parse::<TokenStream>()
            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, format!("{:?}", err)))
    }
}