// sea_orm_codegen/entity/transformer.rs

1use crate::{
2    ActiveEnum, Column, ConjunctRelation, Entity, EntityWriter, Error, PrimaryKey, Relation,
3    RelationType,
4};
5use sea_query::{ColumnSpec, TableCreateStatement};
6use std::collections::{BTreeMap, HashMap};
7
/// Stateless, field-less entry point that turns `CREATE TABLE` statements into
/// entity definitions via [`EntityTransformer::transform`].
#[derive(Debug, Clone)]
pub struct EntityTransformer;

11impl EntityTransformer {
12    pub fn transform(table_create_stmts: Vec<TableCreateStatement>) -> Result<EntityWriter, Error> {
13        let mut enums: BTreeMap<String, ActiveEnum> = BTreeMap::new();
14        let mut inverse_relations: BTreeMap<String, Vec<Relation>> = BTreeMap::new();
15        let mut entities = BTreeMap::new();
16        for table_create in table_create_stmts.into_iter() {
17            let table_name = match table_create.get_table_name() {
18                Some(table_ref) => table_ref.sea_orm_table().to_string(),
19                None => {
20                    return Err(Error::TransformError(
21                        "Table name should not be empty".into(),
22                    ));
23                }
24            };
25            let mut primary_keys: Vec<PrimaryKey> = Vec::new();
26            let columns: Vec<Column> = table_create
27                .get_columns()
28                .iter()
29                .map(|col_def| {
30                    let primary_key = col_def
31                        .get_column_spec()
32                        .iter()
33                        .any(|spec| matches!(spec, ColumnSpec::PrimaryKey));
34                    if primary_key {
35                        primary_keys.push(PrimaryKey {
36                            name: col_def.get_column_name(),
37                        });
38                    }
39                    col_def.into()
40                })
41                .map(|mut col: Column| {
42                    col.unique = table_create
43                        .get_indexes()
44                        .iter()
45                        .filter(|index| index.is_unique_key())
46                        .map(|index| index.get_index_spec().get_column_names())
47                        .filter(|col_names| col_names.len() == 1 && col_names[0] == col.name)
48                        .count()
49                        > 0;
50                    col
51                })
52                .inspect(|col| {
53                    if let sea_query::ColumnType::Enum { name, variants } = col.get_inner_col_type()
54                    {
55                        enums.insert(
56                            name.to_string(),
57                            ActiveEnum {
58                                enum_name: name.clone(),
59                                values: variants.clone(),
60                            },
61                        );
62                    }
63                })
64                .collect();
65            let mut ref_table_counts: BTreeMap<String, usize> = BTreeMap::new();
66            let relations: Vec<Relation> = table_create
67                .get_foreign_key_create_stmts()
68                .iter()
69                .map(|fk_create_stmt| fk_create_stmt.get_foreign_key())
70                .map(|tbl_fk| {
71                    let ref_tbl = tbl_fk.get_ref_table().unwrap().sea_orm_table().to_string();
72                    if let Some(count) = ref_table_counts.get_mut(&ref_tbl) {
73                        if *count == 0 {
74                            *count = 1;
75                        }
76                        *count += 1;
77                    } else {
78                        ref_table_counts.insert(ref_tbl, 0);
79                    };
80                    tbl_fk.into()
81                })
82                .collect::<Vec<_>>()
83                .into_iter()
84                .rev()
85                .map(|mut rel: Relation| {
86                    rel.self_referencing = rel.ref_table == table_name;
87                    if let Some(count) = ref_table_counts.get_mut(&rel.ref_table) {
88                        rel.num_suffix = *count;
89                        if *count > 0 {
90                            *count -= 1;
91                        }
92                    }
93                    rel
94                })
95                .rev()
96                .collect();
97            primary_keys.extend(
98                table_create
99                    .get_indexes()
100                    .iter()
101                    .filter(|index| index.is_primary_key())
102                    .flat_map(|index| {
103                        index
104                            .get_index_spec()
105                            .get_column_names()
106                            .into_iter()
107                            .map(|name| PrimaryKey { name })
108                            .collect::<Vec<_>>()
109                    }),
110            );
111            let entity = Entity {
112                table_name: table_name.clone(),
113                columns,
114                relations: relations.clone(),
115                conjunct_relations: vec![],
116                primary_keys,
117            };
118            entities.insert(table_name.clone(), entity.clone());
119            for mut rel in relations.into_iter() {
120                // This will produce a duplicated relation
121                if rel.self_referencing {
122                    continue;
123                }
124                // This will cause compile error on the many side,
125                // got relation variant but without Related<T> implemented
126                if rel.num_suffix > 0 {
127                    continue;
128                }
129                let ref_table = rel.ref_table;
130                let mut unique = true;
131                for column in rel.columns.iter() {
132                    if !entity
133                        .columns
134                        .iter()
135                        .filter(|col| col.unique)
136                        .any(|col| col.name.as_str() == column)
137                    {
138                        unique = false;
139                        break;
140                    }
141                }
142                if rel.columns.len() == entity.primary_keys.len() {
143                    let mut count_pk = 0;
144                    for primary_key in entity.primary_keys.iter() {
145                        if rel.columns.contains(&primary_key.name) {
146                            count_pk += 1;
147                        }
148                    }
149                    if count_pk == entity.primary_keys.len() {
150                        unique = true;
151                    }
152                }
153                let rel_type = if unique {
154                    RelationType::HasOne
155                } else {
156                    RelationType::HasMany
157                };
158                rel.rel_type = rel_type;
159                rel.ref_table = table_name.to_string();
160                rel.columns = Vec::new();
161                rel.ref_columns = Vec::new();
162                if let Some(vec) = inverse_relations.get_mut(&ref_table) {
163                    vec.push(rel);
164                } else {
165                    inverse_relations.insert(ref_table, vec![rel]);
166                }
167            }
168        }
169        for (tbl_name, relations) in inverse_relations.into_iter() {
170            if let Some(entity) = entities.get_mut(&tbl_name) {
171                for relation in relations.into_iter() {
172                    let duplicate_relation = entity
173                        .relations
174                        .iter()
175                        .any(|rel| rel.ref_table == relation.ref_table);
176                    if !duplicate_relation {
177                        entity.relations.push(relation);
178                    }
179                }
180            }
181        }
182        for table_name in entities.clone().keys() {
183            let relations = match entities.get(table_name) {
184                Some(entity) => {
185                    let is_conjunct_relation =
186                        entity.relations.len() == 2 && entity.primary_keys.len() == 2;
187                    if !is_conjunct_relation {
188                        continue;
189                    }
190                    entity.relations.clone()
191                }
192                None => unreachable!(),
193            };
194            for (i, rel) in relations.iter().enumerate() {
195                let another_rel = relations.get((i == 0) as usize).unwrap();
196                if let Some(entity) = entities.get_mut(&rel.ref_table) {
197                    let conjunct_relation = ConjunctRelation {
198                        via: table_name.clone(),
199                        to: another_rel.ref_table.clone(),
200                    };
201                    entity.conjunct_relations.push(conjunct_relation);
202                }
203            }
204        }
205        Ok(EntityWriter {
206            entities: entities
207                .into_values()
208                .map(|mut v| {
209                    // Filter duplicated conjunct relations
210                    let duplicated_to: Vec<_> = v
211                        .conjunct_relations
212                        .iter()
213                        .fold(HashMap::new(), |mut acc, conjunct_relation| {
214                            acc.entry(conjunct_relation.to.clone())
215                                .and_modify(|c| *c += 1)
216                                .or_insert(1);
217                            acc
218                        })
219                        .into_iter()
220                        .filter(|(_, v)| v > &1)
221                        .map(|(k, _)| k)
222                        .collect();
223                    v.conjunct_relations
224                        .retain(|conjunct_relation| !duplicated_to.contains(&conjunct_relation.to));
225
226                    // Skip `impl Related ... { fn to() ... }` implementation block,
227                    // if the same related entity is being referenced by a conjunct relation
228                    v.relations.iter_mut().for_each(|relation| {
229                        if v.conjunct_relations
230                            .iter()
231                            .any(|conjunct_relation| conjunct_relation.to == relation.ref_table)
232                        {
233                            relation.impl_related = false;
234                        }
235                    });
236
237                    // Sort relation vectors
238                    v.relations.sort_by(|a, b| a.ref_table.cmp(&b.ref_table));
239                    v.conjunct_relations.sort_by(|a, b| a.to.cmp(&b.to));
240                    v
241                })
242                .collect(),
243            enums,
244        })
245    }
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use proc_macro2::TokenStream;
    use sea_orm::{DbBackend, Schema};
    use std::{
        error::Error,
        io::{self, BufRead, BufReader, Read},
    };

    #[test]
    fn duplicated_many_to_many_paths() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::duplicated_many_to_many_paths::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
                schema.create_table_from_entity(users_saved_bills::Entity),
                schema.create_table_from_entity(users_votes::Entity),
            ],
            vec![
                (
                    "bills",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/bills.rs"),
                ),
                (
                    "users",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/users.rs"),
                ),
                (
                    "users_saved_bills",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/users_saved_bills.rs"),
                ),
                (
                    "users_votes",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/users_votes.rs"),
                ),
            ],
        )
    }

    #[test]
    fn many_to_many() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::many_to_many::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
                schema.create_table_from_entity(users_votes::Entity),
            ],
            vec![
                ("bills", include_str!("../tests_cfg/many_to_many/bills.rs")),
                ("users", include_str!("../tests_cfg/many_to_many/users.rs")),
                (
                    "users_votes",
                    include_str!("../tests_cfg/many_to_many/users_votes.rs"),
                ),
            ],
        )
    }

    #[test]
    fn many_to_many_multiple() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::many_to_many_multiple::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
                schema.create_table_from_entity(users_votes::Entity),
            ],
            vec![
                (
                    "bills",
                    include_str!("../tests_cfg/many_to_many_multiple/bills.rs"),
                ),
                (
                    "users",
                    include_str!("../tests_cfg/many_to_many_multiple/users.rs"),
                ),
                (
                    "users_votes",
                    include_str!("../tests_cfg/many_to_many_multiple/users_votes.rs"),
                ),
            ],
        )
    }

    #[test]
    fn self_referencing() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::self_referencing::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
            ],
            vec![
                (
                    "bills",
                    include_str!("../tests_cfg/self_referencing/bills.rs"),
                ),
                (
                    "users",
                    include_str!("../tests_cfg/self_referencing/users.rs"),
                ),
            ],
        )
    }

    /// Runs [`EntityTransformer::transform`] on `table_create_stmts` and asserts that the
    /// compact code generated for each listed entity matches its expected file.
    ///
    /// `files` pairs an entity (table) name with the expected source text. The text is
    /// normalised through [`parse_from_file`] before comparison.
    fn validate_compact_entities(
        table_create_stmts: Vec<TableCreateStatement>,
        files: Vec<(&str, &str)>,
    ) -> Result<(), Box<dyn Error>> {
        let entities: HashMap<_, _> = EntityTransformer::transform(table_create_stmts)?
            .entities
            .into_iter()
            .map(|entity| (entity.table_name.clone(), entity))
            .collect();

        for (entity_name, file_content) in files {
            let entity = entities.get(entity_name).unwrap_or_else(|| {
                panic!(
                    "Entity `{}` not produced by the transformer; forgot to add its table to the list?",
                    entity_name
                )
            });

            let expected = parse_from_file(file_content.as_bytes())?;
            // The first generated block is skipped to line up with `parse_from_file`,
            // which drops the expected file's content up to its first `;`.
            let generated: TokenStream = EntityWriter::gen_compact_code_blocks(
                entity,
                &crate::WithSerde::None,
                &crate::DateTimeCrate::Chrono,
                &None,
                false,
                false,
                &Default::default(),
                &Default::default(),
                false,
                true,
            )
            .into_iter()
            .skip(1)
            .collect();

            assert_eq!(expected.to_string(), generated.to_string());
        }

        Ok(())
    }

    /// Reads an expected entity file, discards everything up to and including its first
    /// `;`, and lexes the remainder into a [`TokenStream`].
    ///
    /// # Errors
    ///
    /// Propagates I/O and UTF-8 errors from `inner`. Returns an
    /// [`io::ErrorKind::InvalidData`] error when the remainder is not valid Rust tokens.
    fn parse_from_file<R>(inner: R) -> io::Result<TokenStream>
    where
        R: io::Read,
    {
        let mut reader = BufReader::new(inner);
        reader.read_until(b';', &mut Vec::new())?;

        let mut content = String::new();
        reader.read_to_string(&mut content)?;
        content.parse::<TokenStream>().map_err(|err| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!("failed to lex expected entity file: {:?}", err),
            )
        })
    }
}