sea_orm_codegen/entity/transformer.rs

1use crate::{
2    ActiveEnum, Column, ConjunctRelation, Entity, EntityWriter, Error, PrimaryKey, Relation,
3    RelationType,
4};
5use sea_query::TableCreateStatement;
6use std::collections::{BTreeMap, HashMap};
7
/// Stateless converter from `sea_query` table definitions to entity descriptions.
///
/// The type carries no data; all of its work is done by the associated
/// `EntityTransformer::transform` function.
#[derive(Debug, Clone)]
pub struct EntityTransformer;
10
11impl EntityTransformer {
12    pub fn transform(table_create_stmts: Vec<TableCreateStatement>) -> Result<EntityWriter, Error> {
13        let mut enums: BTreeMap<String, ActiveEnum> = BTreeMap::new();
14        let mut inverse_relations: BTreeMap<String, Vec<Relation>> = BTreeMap::new();
15        let mut entities = BTreeMap::new();
16        for table_create in table_create_stmts.into_iter() {
17            let table_name = match table_create.get_table_name() {
18                Some(table_ref) => table_ref.sea_orm_table().to_string(),
19                None => {
20                    return Err(Error::TransformError(
21                        "Table name should not be empty".into(),
22                    ));
23                }
24            };
25            let mut primary_keys: Vec<PrimaryKey> = Vec::new();
26            let columns: Vec<Column> = table_create
27                .get_columns()
28                .iter()
29                .map(|col_def| {
30                    let primary_key = col_def.get_column_spec().primary_key;
31                    if primary_key {
32                        primary_keys.push(PrimaryKey {
33                            name: col_def.get_column_name(),
34                        });
35                    }
36                    col_def.into()
37                })
38                .map(|mut col: Column| {
39                    col.unique = table_create
40                        .get_indexes()
41                        .iter()
42                        .filter(|index| index.is_unique_key())
43                        .map(|index| index.get_index_spec().get_column_names())
44                        .filter(|col_names| col_names.len() == 1 && col_names[0] == col.name)
45                        .count()
46                        > 0;
47                    col
48                })
49                .inspect(|col| {
50                    if let sea_query::ColumnType::Enum { name, variants } = col.get_inner_col_type()
51                    {
52                        enums.insert(
53                            name.to_string(),
54                            ActiveEnum {
55                                enum_name: name.clone(),
56                                values: variants.clone(),
57                            },
58                        );
59                    }
60                })
61                .collect();
62            let mut ref_table_counts: BTreeMap<String, usize> = BTreeMap::new();
63            let relations: Vec<Relation> = table_create
64                .get_foreign_key_create_stmts()
65                .iter()
66                .map(|fk_create_stmt| fk_create_stmt.get_foreign_key())
67                .map(|tbl_fk| {
68                    let ref_tbl = tbl_fk.get_ref_table().unwrap().sea_orm_table().to_string();
69                    if let Some(count) = ref_table_counts.get_mut(&ref_tbl) {
70                        if *count == 0 {
71                            *count = 1;
72                        }
73                        *count += 1;
74                    } else {
75                        ref_table_counts.insert(ref_tbl, 0);
76                    };
77                    tbl_fk.into()
78                })
79                .collect::<Vec<_>>()
80                .into_iter()
81                .rev()
82                .map(|mut rel: Relation| {
83                    rel.self_referencing = rel.ref_table == table_name;
84                    if let Some(count) = ref_table_counts.get_mut(&rel.ref_table) {
85                        rel.num_suffix = *count;
86                        if *count > 0 {
87                            *count -= 1;
88                        }
89                    }
90                    rel
91                })
92                .rev()
93                .collect();
94            primary_keys.extend(
95                table_create
96                    .get_indexes()
97                    .iter()
98                    .filter(|index| index.is_primary_key())
99                    .flat_map(|index| {
100                        index
101                            .get_index_spec()
102                            .get_column_names()
103                            .into_iter()
104                            .map(|name| PrimaryKey { name })
105                            .collect::<Vec<_>>()
106                    }),
107            );
108            let entity = Entity {
109                table_name: table_name.clone(),
110                columns,
111                relations: relations.clone(),
112                conjunct_relations: vec![],
113                primary_keys,
114            };
115            entities.insert(table_name.clone(), entity.clone());
116            for mut rel in relations.into_iter() {
117                // This will produce a duplicated relation
118                if rel.self_referencing {
119                    continue;
120                }
121                // This will cause compile error on the many side,
122                // got relation variant but without Related<T> implemented
123                if rel.num_suffix > 0 {
124                    continue;
125                }
126                let ref_table = rel.ref_table;
127                let mut unique = true;
128                for column in rel.columns.iter() {
129                    if !entity
130                        .columns
131                        .iter()
132                        .filter(|col| col.unique)
133                        .any(|col| col.name.as_str() == column)
134                    {
135                        unique = false;
136                        break;
137                    }
138                }
139                if rel.columns.len() == entity.primary_keys.len() {
140                    let mut count_pk = 0;
141                    for primary_key in entity.primary_keys.iter() {
142                        if rel.columns.contains(&primary_key.name) {
143                            count_pk += 1;
144                        }
145                    }
146                    if count_pk == entity.primary_keys.len() {
147                        unique = true;
148                    }
149                }
150                let rel_type = if unique {
151                    RelationType::HasOne
152                } else {
153                    RelationType::HasMany
154                };
155                rel.rel_type = rel_type;
156                rel.ref_table = table_name.to_string();
157                rel.columns = Vec::new();
158                rel.ref_columns = Vec::new();
159                if let Some(vec) = inverse_relations.get_mut(&ref_table) {
160                    vec.push(rel);
161                } else {
162                    inverse_relations.insert(ref_table, vec![rel]);
163                }
164            }
165        }
166        for (tbl_name, relations) in inverse_relations.into_iter() {
167            if let Some(entity) = entities.get_mut(&tbl_name) {
168                for relation in relations.into_iter() {
169                    let duplicate_relation = entity
170                        .relations
171                        .iter()
172                        .any(|rel| rel.ref_table == relation.ref_table);
173                    if !duplicate_relation {
174                        entity.relations.push(relation);
175                    }
176                }
177            }
178        }
179        for table_name in entities.clone().keys() {
180            let relations = match entities.get(table_name) {
181                Some(entity) => {
182                    let is_conjunct_relation =
183                        entity.relations.len() == 2 && entity.primary_keys.len() == 2;
184                    if !is_conjunct_relation {
185                        continue;
186                    }
187                    entity.relations.clone()
188                }
189                None => unreachable!(),
190            };
191            for (i, rel) in relations.iter().enumerate() {
192                let another_rel = relations.get((i == 0) as usize).unwrap();
193                if let Some(entity) = entities.get_mut(&rel.ref_table) {
194                    let conjunct_relation = ConjunctRelation {
195                        via: table_name.clone(),
196                        to: another_rel.ref_table.clone(),
197                    };
198                    entity.conjunct_relations.push(conjunct_relation);
199                }
200            }
201        }
202        Ok(EntityWriter {
203            entities: entities
204                .into_values()
205                .map(|mut v| {
206                    // Filter duplicated conjunct relations
207                    let duplicated_to: Vec<_> = v
208                        .conjunct_relations
209                        .iter()
210                        .fold(HashMap::new(), |mut acc, conjunct_relation| {
211                            acc.entry(conjunct_relation.to.clone())
212                                .and_modify(|c| *c += 1)
213                                .or_insert(1);
214                            acc
215                        })
216                        .into_iter()
217                        .filter(|(_, v)| v > &1)
218                        .map(|(k, _)| k)
219                        .collect();
220                    v.conjunct_relations
221                        .retain(|conjunct_relation| !duplicated_to.contains(&conjunct_relation.to));
222
223                    // Skip `impl Related ... { fn to() ... }` implementation block,
224                    // if the same related entity is being referenced by a conjunct relation
225                    v.relations.iter_mut().for_each(|relation| {
226                        if v.conjunct_relations
227                            .iter()
228                            .any(|conjunct_relation| conjunct_relation.to == relation.ref_table)
229                        {
230                            relation.impl_related = false;
231                        }
232                    });
233
234                    // Sort relation vectors
235                    v.relations.sort_by(|a, b| a.ref_table.cmp(&b.ref_table));
236                    v.conjunct_relations.sort_by(|a, b| a.to.cmp(&b.to));
237                    v
238                })
239                .collect(),
240            enums,
241        })
242    }
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use proc_macro2::TokenStream;
    use sea_orm::{DbBackend, Schema};
    use std::{
        error::Error,
        io::{self, BufRead, BufReader, Read},
    };

    #[test]
    fn duplicated_many_to_many_paths() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::duplicated_many_to_many_paths::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
                schema.create_table_from_entity(users_saved_bills::Entity),
                schema.create_table_from_entity(users_votes::Entity),
            ],
            vec![
                (
                    "bills",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/bills.rs"),
                ),
                (
                    "users",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/users.rs"),
                ),
                (
                    "users_saved_bills",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/users_saved_bills.rs"),
                ),
                (
                    "users_votes",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/users_votes.rs"),
                ),
            ],
        )
    }

    #[test]
    fn many_to_many() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::many_to_many::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
                schema.create_table_from_entity(users_votes::Entity),
            ],
            vec![
                ("bills", include_str!("../tests_cfg/many_to_many/bills.rs")),
                ("users", include_str!("../tests_cfg/many_to_many/users.rs")),
                (
                    "users_votes",
                    include_str!("../tests_cfg/many_to_many/users_votes.rs"),
                ),
            ],
        )
    }

    #[test]
    fn many_to_many_multiple() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::many_to_many_multiple::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
                schema.create_table_from_entity(users_votes::Entity),
            ],
            vec![
                (
                    "bills",
                    include_str!("../tests_cfg/many_to_many_multiple/bills.rs"),
                ),
                (
                    "users",
                    include_str!("../tests_cfg/many_to_many_multiple/users.rs"),
                ),
                (
                    "users_votes",
                    include_str!("../tests_cfg/many_to_many_multiple/users_votes.rs"),
                ),
            ],
        )
    }

    #[test]
    fn self_referencing() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::self_referencing::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
            ],
            vec![
                (
                    "bills",
                    include_str!("../tests_cfg/self_referencing/bills.rs"),
                ),
                (
                    "users",
                    include_str!("../tests_cfg/self_referencing/users.rs"),
                ),
            ],
        )
    }

    /// Runs [`EntityTransformer::transform`] on `table_create_stmts`. For every
    /// `(entity_name, file_content)` pair it asserts that the compact code generated for that
    /// entity equals the tokens of `file_content`.
    ///
    /// The first generated code block is skipped, and `parse_from_file` drops everything up
    /// to the first `;` of the expected file; presumably both correspond to the leading
    /// `use` line (TODO confirm).
    fn validate_compact_entities(
        table_create_stmts: Vec<TableCreateStatement>,
        files: Vec<(&str, &str)>,
    ) -> Result<(), Box<dyn Error>> {
        let entities: HashMap<_, _> = EntityTransformer::transform(table_create_stmts)?
            .entities
            .into_iter()
            .map(|entity| (entity.table_name.clone(), entity))
            .collect();

        for (entity_name, file_content) in files {
            let entity = entities
                .get(entity_name)
                .expect("Forget to add entity to the list");

            let generated: TokenStream = EntityWriter::gen_compact_code_blocks(
                entity,
                &crate::WithSerde::None,
                &crate::DateTimeCrate::Chrono,
                &None,
                false,
                false,
                &Default::default(),
                &Default::default(),
                &Default::default(),
                false,
                true,
            )
            .into_iter()
            .skip(1)
            .collect();

            assert_eq!(
                parse_from_file(file_content.as_bytes())?.to_string(),
                generated.to_string()
            );
        }

        Ok(())
    }

    /// Reads `inner` and parses it as Rust tokens, ignoring everything up to and including
    /// the first `;`.
    ///
    /// Returns an `io::Error` if reading fails, if the content is not valid UTF-8, or if it
    /// cannot be tokenized (`ErrorKind::InvalidData`).
    fn parse_from_file<R>(inner: R) -> io::Result<TokenStream>
    where
        R: io::Read,
    {
        let mut reader = BufReader::new(inner);
        reader.read_until(b';', &mut Vec::new())?;

        let mut content = String::new();
        reader.read_to_string(&mut content)?;
        content.parse::<TokenStream>().map_err(|err| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!("failed to tokenize expected entity file: {:?}", err),
            )
        })
    }
}