sea_orm_codegen/entity/
base_entity.rs

1use heck::{ToSnakeCase, ToUpperCamelCase};
2use proc_macro2::{Ident, TokenStream};
3use quote::format_ident;
4use quote::quote;
5use sea_query::ColumnType;
6
7use crate::{
8    util::escape_rust_keyword, Column, ConjunctRelation, DateTimeCrate, PrimaryKey, Relation,
9};
10
11#[derive(Clone, Debug)]
12pub struct Entity {
13    pub(crate) table_name: String,
14    pub(crate) columns: Vec<Column>,
15    pub(crate) relations: Vec<Relation>,
16    pub(crate) conjunct_relations: Vec<ConjunctRelation>,
17    pub(crate) primary_keys: Vec<PrimaryKey>,
18}
19
20impl Entity {
21    pub fn get_table_name_snake_case(&self) -> String {
22        self.table_name.to_snake_case()
23    }
24
25    pub fn get_table_name_camel_case(&self) -> String {
26        self.table_name.to_upper_camel_case()
27    }
28
29    pub fn get_table_name_snake_case_ident(&self) -> Ident {
30        format_ident!("{}", escape_rust_keyword(self.get_table_name_snake_case()))
31    }
32
33    pub fn get_table_name_camel_case_ident(&self) -> Ident {
34        format_ident!("{}", escape_rust_keyword(self.get_table_name_camel_case()))
35    }
36
37    pub fn get_column_names_snake_case(&self) -> Vec<Ident> {
38        self.columns
39            .iter()
40            .map(|col| col.get_name_snake_case())
41            .collect()
42    }
43
44    pub fn get_column_names_camel_case(&self) -> Vec<Ident> {
45        self.columns
46            .iter()
47            .map(|col| col.get_name_camel_case())
48            .collect()
49    }
50
51    pub fn get_column_rs_types(&self, date_time_crate: &DateTimeCrate) -> Vec<TokenStream> {
52        self.columns
53            .clone()
54            .into_iter()
55            .map(|col| col.get_rs_type(date_time_crate))
56            .collect()
57    }
58
59    pub fn get_column_defs(&self) -> Vec<TokenStream> {
60        self.columns
61            .clone()
62            .into_iter()
63            .map(|col| col.get_def())
64            .collect()
65    }
66
67    pub fn get_primary_key_names_snake_case(&self) -> Vec<Ident> {
68        self.primary_keys
69            .iter()
70            .map(|pk| pk.get_name_snake_case())
71            .collect()
72    }
73
74    pub fn get_primary_key_names_camel_case(&self) -> Vec<Ident> {
75        self.primary_keys
76            .iter()
77            .map(|pk| pk.get_name_camel_case())
78            .collect()
79    }
80
81    pub fn get_relation_module_name(&self) -> Vec<Option<Ident>> {
82        self.relations
83            .iter()
84            .map(|rel| rel.get_module_name())
85            .collect()
86    }
87
88    pub fn get_relation_enum_name(&self) -> Vec<Ident> {
89        self.relations
90            .iter()
91            .map(|rel| rel.get_enum_name())
92            .collect()
93    }
94
95    /// Used to generate the names for the `enum RelatedEntity` that is useful to the Seaography project
96    pub fn get_related_entity_enum_name(&self) -> Vec<Ident> {
97        // 1st step get conjunct relations data
98        let conjunct_related_names = self.get_conjunct_relations_to_upper_camel_case();
99
100        // 2nd step get reverse self relations data
101        let self_relations_reverse = self
102            .relations
103            .iter()
104            .filter(|rel| rel.self_referencing)
105            .map(|rel| format_ident!("{}Reverse", rel.get_enum_name()));
106
107        // 3rd step get normal relations data
108        self.get_relation_enum_name()
109            .into_iter()
110            .chain(self_relations_reverse)
111            .chain(conjunct_related_names)
112            .collect()
113    }
114
115    pub fn get_relation_defs(&self) -> Vec<TokenStream> {
116        self.relations.iter().map(|rel| rel.get_def()).collect()
117    }
118
119    pub fn get_relation_attrs(&self) -> Vec<TokenStream> {
120        self.relations.iter().map(|rel| rel.get_attrs()).collect()
121    }
122
123    /// Used to generate the attributes for the `enum RelatedEntity` that is useful to the Seaography project
124    pub fn get_related_entity_attrs(&self) -> Vec<TokenStream> {
125        // 1st step get conjunct relations data
126        let conjunct_related_attrs = self.conjunct_relations.iter().map(|conj| {
127            let entity = format!("super::{}::Entity", conj.get_to_snake_case());
128
129            quote! {
130                #[sea_orm(
131                    entity = #entity
132                )]
133            }
134        });
135
136        // helper function that generates attributes for `Relation` data
137        let produce_relation_attrs = |rel: &Relation, reverse: bool| {
138            let entity = match rel.get_module_name() {
139                Some(module_name) => format!("super::{}::Entity", module_name),
140                None => String::from("Entity"),
141            };
142
143            if rel.self_referencing || !rel.impl_related || rel.num_suffix > 0 {
144                let def = if reverse {
145                    format!("Relation::{}.def().rev()", rel.get_enum_name())
146                } else {
147                    format!("Relation::{}.def()", rel.get_enum_name())
148                };
149
150                quote! {
151                    #[sea_orm(
152                        entity = #entity,
153                        def = #def
154                    )]
155                }
156            } else {
157                quote! {
158                    #[sea_orm(
159                        entity = #entity
160                    )]
161                }
162            }
163        };
164
165        // 2nd step get reverse self relations data
166        let self_relations_reverse_attrs = self
167            .relations
168            .iter()
169            .filter(|rel| rel.self_referencing)
170            .map(|rel| produce_relation_attrs(rel, true));
171
172        // 3rd step get normal relations data
173        self.relations
174            .iter()
175            .map(|rel| produce_relation_attrs(rel, false))
176            .chain(self_relations_reverse_attrs)
177            .chain(conjunct_related_attrs)
178            .collect()
179    }
180
181    pub fn get_primary_key_auto_increment(&self) -> Ident {
182        let auto_increment = self.columns.iter().any(|col| col.auto_increment);
183        format_ident!("{}", auto_increment)
184    }
185
186    pub fn get_primary_key_rs_type(&self, date_time_crate: &DateTimeCrate) -> TokenStream {
187        let types = self
188            .primary_keys
189            .iter()
190            .map(|primary_key| {
191                self.columns
192                    .iter()
193                    .find(|col| col.name.eq(&primary_key.name))
194                    .unwrap()
195                    .get_rs_type(date_time_crate)
196                    .to_string()
197            })
198            .collect::<Vec<_>>();
199        if !types.is_empty() {
200            let value_type = if types.len() > 1 {
201                vec!["(".to_owned(), types.join(", "), ")".to_owned()]
202            } else {
203                types
204            };
205            value_type.join("").parse().unwrap()
206        } else {
207            TokenStream::new()
208        }
209    }
210
211    pub fn get_conjunct_relations_via_snake_case(&self) -> Vec<Ident> {
212        self.conjunct_relations
213            .iter()
214            .map(|con_rel| con_rel.get_via_snake_case())
215            .collect()
216    }
217
218    pub fn get_conjunct_relations_to_snake_case(&self) -> Vec<Ident> {
219        self.conjunct_relations
220            .iter()
221            .map(|con_rel| con_rel.get_to_snake_case())
222            .collect()
223    }
224
225    pub fn get_conjunct_relations_to_upper_camel_case(&self) -> Vec<Ident> {
226        self.conjunct_relations
227            .iter()
228            .map(|con_rel| con_rel.get_to_upper_camel_case())
229            .collect()
230    }
231
232    pub fn get_eq_needed(&self) -> TokenStream {
233        fn is_floats(col_type: &ColumnType) -> bool {
234            match col_type {
235                ColumnType::Float | ColumnType::Double => true,
236                ColumnType::Array(col_type) => is_floats(col_type),
237                _ => false,
238            }
239        }
240        self.columns
241            .iter()
242            .find(|column| is_floats(&column.col_type))
243            // check if float or double exist.
244            // if exist, return nothing
245            .map_or(quote! {, Eq}, |_| quote! {})
246    }
247
248    pub fn get_column_serde_attributes(
249        &self,
250        serde_skip_deserializing_primary_key: bool,
251        serde_skip_hidden_column: bool,
252    ) -> Vec<TokenStream> {
253        self.columns
254            .iter()
255            .map(|col| {
256                let is_primary_key = self.primary_keys.iter().any(|pk| pk.name == col.name);
257                col.get_serde_attribute(
258                    is_primary_key,
259                    serde_skip_deserializing_primary_key,
260                    serde_skip_hidden_column,
261                )
262            })
263            .collect()
264    }
265}
266
#[cfg(test)]
mod tests {
    use quote::{format_ident, quote};
    use sea_query::{ColumnType, ForeignKeyAction, StringLen};

    use crate::{Column, DateTimeCrate, Entity, PrimaryKey, Relation, RelationType};

    /// Builds a `special_cake` entity with two nullable columns (`id`, `name`),
    /// two `HasOne` relations (`fruit`, `filling`), no conjunct relations and
    /// `id` as the sole primary key.
    fn setup() -> Entity {
        Entity {
            table_name: "special_cake".to_owned(),
            columns: vec![
                Column {
                    name: "id".to_owned(),
                    col_type: ColumnType::Integer,
                    auto_increment: false,
                    not_null: false,
                    unique: false,
                },
                Column {
                    name: "name".to_owned(),
                    col_type: ColumnType::String(StringLen::None),
                    auto_increment: false,
                    not_null: false,
                    unique: false,
                },
            ],
            relations: vec![
                Relation {
                    ref_table: "fruit".to_owned(),
                    columns: vec!["id".to_owned()],
                    ref_columns: vec!["cake_id".to_owned()],
                    rel_type: RelationType::HasOne,
                    on_delete: Some(ForeignKeyAction::Cascade),
                    on_update: Some(ForeignKeyAction::Cascade),
                    self_referencing: false,
                    num_suffix: 0,
                    impl_related: true,
                },
                Relation {
                    ref_table: "filling".to_owned(),
                    columns: vec!["id".to_owned()],
                    ref_columns: vec!["cake_id".to_owned()],
                    rel_type: RelationType::HasOne,
                    on_delete: Some(ForeignKeyAction::Cascade),
                    on_update: Some(ForeignKeyAction::Cascade),
                    self_referencing: false,
                    num_suffix: 0,
                    impl_related: true,
                },
            ],
            conjunct_relations: vec![],
            primary_keys: vec![PrimaryKey {
                name: "id".to_owned(),
            }],
        }
    }

    #[test]
    fn test_get_table_name_snake_case() {
        let entity = setup();

        assert_eq!(
            entity.get_table_name_snake_case(),
            "special_cake".to_owned()
        );
    }

    #[test]
    fn test_get_table_name_camel_case() {
        let entity = setup();

        assert_eq!(entity.get_table_name_camel_case(), "SpecialCake".to_owned());
    }

    #[test]
    fn test_get_table_name_snake_case_ident() {
        let entity = setup();

        assert_eq!(
            entity.get_table_name_snake_case_ident(),
            format_ident!("{}", "special_cake")
        );
    }

    #[test]
    fn test_get_table_name_camel_case_ident() {
        let entity = setup();

        assert_eq!(
            entity.get_table_name_camel_case_ident(),
            format_ident!("{}", "SpecialCake")
        );
    }

    // The per-element tests below check the length first so that an empty or
    // truncated result cannot pass vacuously.

    #[test]
    fn test_get_column_names_snake_case() {
        let entity = setup();
        let names = entity.get_column_names_snake_case();

        assert_eq!(names.len(), entity.columns.len());
        for (name, col) in names.iter().zip(&entity.columns) {
            assert_eq!(name, &col.get_name_snake_case());
        }
    }

    #[test]
    fn test_get_column_names_camel_case() {
        let entity = setup();
        let names = entity.get_column_names_camel_case();

        assert_eq!(names.len(), entity.columns.len());
        for (name, col) in names.iter().zip(&entity.columns) {
            assert_eq!(name, &col.get_name_camel_case());
        }
    }

    #[test]
    fn test_get_column_rs_types() {
        let entity = setup();
        let types = entity.get_column_rs_types(&DateTimeCrate::Chrono);

        assert_eq!(types.len(), entity.columns.len());
        for (ty, col) in types.iter().zip(&entity.columns) {
            assert_eq!(
                ty.to_string(),
                col.get_rs_type(&DateTimeCrate::Chrono).to_string()
            );
        }
    }

    #[test]
    fn test_get_column_defs() {
        let entity = setup();
        let defs = entity.get_column_defs();

        assert_eq!(defs.len(), entity.columns.len());
        for (def, col) in defs.iter().zip(&entity.columns) {
            assert_eq!(def.to_string(), col.get_def().to_string());
        }
    }

    #[test]
    fn test_get_primary_key_names_snake_case() {
        let entity = setup();
        let names = entity.get_primary_key_names_snake_case();

        assert_eq!(names.len(), entity.primary_keys.len());
        for (name, pk) in names.iter().zip(&entity.primary_keys) {
            assert_eq!(name, &pk.get_name_snake_case());
        }
    }

    #[test]
    fn test_get_primary_key_names_camel_case() {
        let entity = setup();
        let names = entity.get_primary_key_names_camel_case();

        assert_eq!(names.len(), entity.primary_keys.len());
        for (name, pk) in names.iter().zip(&entity.primary_keys) {
            assert_eq!(name, &pk.get_name_camel_case());
        }
    }

    #[test]
    fn test_get_relation_module_name() {
        let entity = setup();
        let names = entity.get_relation_module_name();

        assert_eq!(names.len(), entity.relations.len());
        for (name, rel) in names.iter().zip(&entity.relations) {
            assert_eq!(name, &rel.get_module_name());
        }
    }

    #[test]
    fn test_get_relation_enum_name() {
        let entity = setup();
        let names = entity.get_relation_enum_name();

        assert_eq!(names.len(), entity.relations.len());
        for (name, rel) in names.iter().zip(&entity.relations) {
            assert_eq!(name, &rel.get_enum_name());
        }
    }

    #[test]
    fn test_get_relation_defs() {
        let entity = setup();
        let defs = entity.get_relation_defs();

        assert_eq!(defs.len(), entity.relations.len());
        for (def, rel) in defs.iter().zip(&entity.relations) {
            assert_eq!(def.to_string(), rel.get_def().to_string());
        }
    }

    #[test]
    fn test_get_relation_attrs() {
        let entity = setup();
        let attrs = entity.get_relation_attrs();

        assert_eq!(attrs.len(), entity.relations.len());
        for (attr, rel) in attrs.iter().zip(&entity.relations) {
            assert_eq!(attr.to_string(), rel.get_attrs().to_string());
        }
    }

    #[test]
    fn test_get_related_entity_without_self_reference() {
        let entity = setup();

        // No self-referencing and no conjunct relations: one entry per relation.
        assert_eq!(
            entity.get_related_entity_enum_name(),
            entity.get_relation_enum_name()
        );
        assert_eq!(
            entity.get_related_entity_attrs().len(),
            entity.relations.len()
        );
    }

    #[test]
    fn test_get_related_entity_with_self_reference() {
        let mut entity = setup();
        entity.relations[0].self_referencing = true;

        let names = entity.get_related_entity_enum_name();
        let attrs = entity.get_related_entity_attrs();

        // Two direct relations plus one `...Reverse` entry for the self-referencing one.
        assert_eq!(names.len(), 3);
        assert_eq!(attrs.len(), names.len());
        assert_eq!(
            names[2],
            format_ident!("{}Reverse", entity.relations[0].get_enum_name())
        );
        assert!(attrs[2].to_string().contains(".def().rev()"));
        assert!(!attrs[0].to_string().contains(".rev()"));
    }

    #[test]
    fn test_get_primary_key_auto_increment() {
        let mut entity = setup();

        assert_eq!(
            entity.get_primary_key_auto_increment(),
            format_ident!("{}", false)
        );

        entity.columns[0].auto_increment = true;
        assert_eq!(
            entity.get_primary_key_auto_increment(),
            format_ident!("{}", true)
        );
    }

    #[test]
    fn test_get_primary_key_rs_type() {
        let entity = setup();

        assert_eq!(
            entity
                .get_primary_key_rs_type(&DateTimeCrate::Chrono)
                .to_string(),
            entity.columns[0]
                .get_rs_type(&DateTimeCrate::Chrono)
                .to_string()
        );
    }

    #[test]
    fn test_get_primary_key_rs_type_composite() {
        let mut entity = setup();
        entity.primary_keys.push(PrimaryKey {
            name: "name".to_owned(),
        });

        let id_type = entity.columns[0].get_rs_type(&DateTimeCrate::Chrono);
        let name_type = entity.columns[1].get_rs_type(&DateTimeCrate::Chrono);
        assert_eq!(
            entity
                .get_primary_key_rs_type(&DateTimeCrate::Chrono)
                .to_string(),
            quote! { (#id_type, #name_type) }.to_string()
        );
    }

    #[test]
    fn test_get_primary_key_rs_type_empty() {
        let mut entity = setup();
        entity.primary_keys.clear();

        assert!(entity
            .get_primary_key_rs_type(&DateTimeCrate::Chrono)
            .is_empty());
    }

    // `setup` has no conjunct relations, so these three only pin the length.

    #[test]
    fn test_get_conjunct_relations_via_snake_case() {
        let entity = setup();
        let names = entity.get_conjunct_relations_via_snake_case();

        assert_eq!(names.len(), entity.conjunct_relations.len());
        for (name, con_rel) in names.iter().zip(&entity.conjunct_relations) {
            assert_eq!(name, &con_rel.get_via_snake_case());
        }
    }

    #[test]
    fn test_get_conjunct_relations_to_snake_case() {
        let entity = setup();
        let names = entity.get_conjunct_relations_to_snake_case();

        assert_eq!(names.len(), entity.conjunct_relations.len());
        for (name, con_rel) in names.iter().zip(&entity.conjunct_relations) {
            assert_eq!(name, &con_rel.get_to_snake_case());
        }
    }

    #[test]
    fn test_get_conjunct_relations_to_upper_camel_case() {
        let entity = setup();
        let names = entity.get_conjunct_relations_to_upper_camel_case();

        assert_eq!(names.len(), entity.conjunct_relations.len());
        for (name, con_rel) in names.iter().zip(&entity.conjunct_relations) {
            assert_eq!(name, &con_rel.get_to_upper_camel_case());
        }
    }

    #[test]
    fn test_get_eq_needed() {
        let entity = setup();
        let expected = quote! {, Eq};

        assert_eq!(entity.get_eq_needed().to_string(), expected.to_string());
    }

    #[test]
    fn test_get_eq_needed_with_float_column() {
        let mut entity = setup();
        entity.columns[1].col_type = ColumnType::Double;

        assert_eq!(entity.get_eq_needed().to_string(), "");
    }
}