sea_orm_codegen/entity/base_entity.rs

1use heck::{ToSnakeCase, ToUpperCamelCase};
2use proc_macro2::{Ident, TokenStream};
3use quote::format_ident;
4use quote::quote;
5use sea_query::ColumnType;
6
7use crate::{
8    Column, ConjunctRelation, DateTimeCrate, PrimaryKey, Relation, util::escape_rust_keyword,
9};
10
11#[derive(Clone, Debug)]
12pub struct Entity {
13    pub(crate) table_name: String,
14    pub(crate) columns: Vec<Column>,
15    pub(crate) relations: Vec<Relation>,
16    pub(crate) conjunct_relations: Vec<ConjunctRelation>,
17    pub(crate) primary_keys: Vec<PrimaryKey>,
18}
19
20impl Entity {
21    pub fn get_table_name_snake_case(&self) -> String {
22        self.table_name.to_snake_case()
23    }
24
25    pub fn get_table_name_camel_case(&self) -> String {
26        self.table_name.to_upper_camel_case()
27    }
28
29    pub fn get_table_name_snake_case_ident(&self) -> Ident {
30        format_ident!("{}", escape_rust_keyword(self.get_table_name_snake_case()))
31    }
32
33    pub fn get_table_name_camel_case_ident(&self) -> Ident {
34        format_ident!("{}", escape_rust_keyword(self.get_table_name_camel_case()))
35    }
36
37    pub fn get_column_names_snake_case(&self) -> Vec<Ident> {
38        self.columns
39            .iter()
40            .map(|col| col.get_name_snake_case())
41            .collect()
42    }
43
44    pub fn get_column_names_camel_case(&self) -> Vec<Ident> {
45        self.columns
46            .iter()
47            .map(|col| col.get_name_camel_case())
48            .collect()
49    }
50
51    pub fn get_column_rs_types(&self, date_time_crate: &DateTimeCrate) -> Vec<TokenStream> {
52        self.columns
53            .clone()
54            .into_iter()
55            .map(|col| col.get_rs_type(date_time_crate))
56            .collect()
57    }
58
59    pub fn get_column_defs(&self) -> Vec<TokenStream> {
60        self.columns
61            .clone()
62            .into_iter()
63            .map(|col| col.get_def())
64            .collect()
65    }
66
67    pub fn get_primary_key_names_snake_case(&self) -> Vec<Ident> {
68        self.primary_keys
69            .iter()
70            .map(|pk| pk.get_name_snake_case())
71            .collect()
72    }
73
74    pub fn get_primary_key_names_camel_case(&self) -> Vec<Ident> {
75        self.primary_keys
76            .iter()
77            .map(|pk| pk.get_name_camel_case())
78            .collect()
79    }
80
81    pub fn get_relation_module_name(&self) -> Vec<Option<Ident>> {
82        self.relations
83            .iter()
84            .map(|rel| rel.get_module_name())
85            .collect()
86    }
87
88    pub fn get_relation_enum_name(&self) -> Vec<Ident> {
89        self.relations
90            .iter()
91            .map(|rel| rel.get_enum_name())
92            .collect()
93    }
94
95    /// Used to generate the names for the `enum RelatedEntity` that is useful to the Seaography project
96    pub fn get_related_entity_enum_name(&self) -> Vec<Ident> {
97        // 1st step get conjunct relations data
98        let conjunct_related_names = self.get_conjunct_relations_to_upper_camel_case();
99
100        // 2nd step get reverse self relations data
101        let self_relations_reverse = self
102            .relations
103            .iter()
104            .filter(|rel| rel.self_referencing)
105            .map(|rel| format_ident!("{}Reverse", rel.get_enum_name()));
106
107        // 3rd step get normal relations data
108        self.get_relation_enum_name()
109            .into_iter()
110            .chain(self_relations_reverse)
111            .chain(conjunct_related_names)
112            .collect()
113    }
114
115    pub fn get_relation_defs(&self) -> Vec<TokenStream> {
116        self.relations.iter().map(|rel| rel.get_def()).collect()
117    }
118
119    pub fn get_relation_attrs(&self) -> Vec<TokenStream> {
120        self.relations.iter().map(|rel| rel.get_attrs()).collect()
121    }
122
123    /// Trimmed get_related_entity_attrs down to just the entity module
124    pub fn get_related_entity_modules(&self) -> Vec<Ident> {
125        // 1st step get conjunct relations data
126        let conjunct_related_attrs = self
127            .conjunct_relations
128            .iter()
129            .map(|conj| conj.get_to_snake_case());
130
131        // helper function that generates attributes for `Relation` data
132        let produce_relation_attrs = |rel: &Relation, _reverse: bool| match rel.get_module_name() {
133            Some(module_name) => module_name,
134            None => format_ident!("self"),
135        };
136
137        // 2nd step get reverse self relations data
138        let self_relations_reverse_attrs = self
139            .relations
140            .iter()
141            .filter(|rel| rel.self_referencing)
142            .map(|rel| produce_relation_attrs(rel, true));
143
144        // 3rd step get normal relations data
145        self.relations
146            .iter()
147            .map(|rel| produce_relation_attrs(rel, false))
148            .chain(self_relations_reverse_attrs)
149            .chain(conjunct_related_attrs)
150            .collect()
151    }
152
153    /// Used to generate the attributes for the `enum RelatedEntity` that is useful to the Seaography project
154    pub fn get_related_entity_attrs(&self) -> Vec<TokenStream> {
155        // 1st step get conjunct relations data
156        let conjunct_related_attrs = self.conjunct_relations.iter().map(|conj| {
157            let entity = format!("super::{}::Entity", conj.get_to_snake_case());
158
159            quote! {
160                #[sea_orm(
161                    entity = #entity
162                )]
163            }
164        });
165
166        // helper function that generates attributes for `Relation` data
167        let produce_relation_attrs = |rel: &Relation, reverse: bool| {
168            let entity = match rel.get_module_name() {
169                Some(module_name) => format!("super::{module_name}::Entity"),
170                None => String::from("Entity"),
171            };
172
173            if rel.self_referencing || !rel.impl_related || rel.num_suffix > 0 {
174                let def = if reverse {
175                    format!("Relation::{}.def().rev()", rel.get_enum_name())
176                } else {
177                    format!("Relation::{}.def()", rel.get_enum_name())
178                };
179
180                quote! {
181                    #[sea_orm(
182                        entity = #entity,
183                        def = #def
184                    )]
185                }
186            } else {
187                quote! {
188                    #[sea_orm(
189                        entity = #entity
190                    )]
191                }
192            }
193        };
194
195        // 2nd step get reverse self relations data
196        let self_relations_reverse_attrs = self
197            .relations
198            .iter()
199            .filter(|rel| rel.self_referencing)
200            .map(|rel| produce_relation_attrs(rel, true));
201
202        // 3rd step get normal relations data
203        self.relations
204            .iter()
205            .map(|rel| produce_relation_attrs(rel, false))
206            .chain(self_relations_reverse_attrs)
207            .chain(conjunct_related_attrs)
208            .collect()
209    }
210
211    pub fn get_primary_key_auto_increment(&self) -> Ident {
212        let auto_increment = self.columns.iter().any(|col| col.auto_increment);
213        format_ident!("{}", auto_increment)
214    }
215
216    pub fn get_primary_key_rs_type(&self, date_time_crate: &DateTimeCrate) -> TokenStream {
217        let types = self
218            .primary_keys
219            .iter()
220            .map(|primary_key| {
221                self.columns
222                    .iter()
223                    .find(|col| col.name.eq(&primary_key.name))
224                    .unwrap()
225                    .get_rs_type(date_time_crate)
226                    .to_string()
227            })
228            .collect::<Vec<_>>();
229        if !types.is_empty() {
230            let value_type = if types.len() > 1 {
231                vec!["(".to_owned(), types.join(", "), ")".to_owned()]
232            } else {
233                types
234            };
235            value_type.join("").parse().unwrap()
236        } else {
237            TokenStream::new()
238        }
239    }
240
241    pub fn get_conjunct_relations_via_snake_case(&self) -> Vec<Ident> {
242        self.conjunct_relations
243            .iter()
244            .map(|con_rel| con_rel.get_via_snake_case())
245            .collect()
246    }
247
248    pub fn get_conjunct_relations_to_snake_case(&self) -> Vec<Ident> {
249        self.conjunct_relations
250            .iter()
251            .map(|con_rel| con_rel.get_to_snake_case())
252            .collect()
253    }
254
255    pub fn get_conjunct_relations_to_upper_camel_case(&self) -> Vec<Ident> {
256        self.conjunct_relations
257            .iter()
258            .map(|con_rel| con_rel.get_to_upper_camel_case())
259            .collect()
260    }
261
262    pub fn get_eq_needed(&self) -> TokenStream {
263        fn is_floats(col_type: &ColumnType) -> bool {
264            match col_type {
265                ColumnType::Float | ColumnType::Double => true,
266                ColumnType::Array(col_type) => is_floats(col_type),
267                ColumnType::Vector(_) => true,
268                _ => false,
269            }
270        }
271        self.columns
272            .iter()
273            .find(|column| is_floats(&column.col_type))
274            // check if float or double exist.
275            // if exist, return nothing
276            .map_or(quote! {, Eq}, |_| quote! {})
277    }
278
279    pub fn get_column_serde_attributes(
280        &self,
281        serde_skip_deserializing_primary_key: bool,
282        serde_skip_hidden_column: bool,
283    ) -> Vec<TokenStream> {
284        self.columns
285            .iter()
286            .map(|col| {
287                let is_primary_key = self.primary_keys.iter().any(|pk| pk.name == col.name);
288                col.get_serde_attribute(
289                    is_primary_key,
290                    serde_skip_deserializing_primary_key,
291                    serde_skip_hidden_column,
292                )
293            })
294            .collect()
295    }
296}
297
#[cfg(test)]
mod tests {
    use quote::{format_ident, quote};
    use sea_query::{ColumnType, ForeignKeyAction, StringLen};

    use crate::{Column, DateTimeCrate, Entity, PrimaryKey, Relation, RelationType};

    /// Builds a `special_cake` entity with two columns (`id`, `name`), two `HasOne`
    /// relations (`fruit`, `filling`) and `id` as its only primary key.
    fn setup() -> Entity {
        Entity {
            table_name: "special_cake".to_owned(),
            columns: vec![
                Column {
                    name: "id".to_owned(),
                    col_type: ColumnType::Integer,
                    auto_increment: false,
                    not_null: false,
                    unique: false,
                },
                Column {
                    name: "name".to_owned(),
                    col_type: ColumnType::String(StringLen::None),
                    auto_increment: false,
                    not_null: false,
                    unique: false,
                },
            ],
            relations: vec![
                Relation {
                    ref_table: "fruit".to_owned(),
                    columns: vec!["id".to_owned()],
                    ref_columns: vec!["cake_id".to_owned()],
                    rel_type: RelationType::HasOne,
                    on_delete: Some(ForeignKeyAction::Cascade),
                    on_update: Some(ForeignKeyAction::Cascade),
                    self_referencing: false,
                    num_suffix: 0,
                    impl_related: true,
                },
                Relation {
                    ref_table: "filling".to_owned(),
                    columns: vec!["id".to_owned()],
                    ref_columns: vec!["cake_id".to_owned()],
                    rel_type: RelationType::HasOne,
                    on_delete: Some(ForeignKeyAction::Cascade),
                    on_update: Some(ForeignKeyAction::Cascade),
                    self_referencing: false,
                    num_suffix: 0,
                    impl_related: true,
                },
            ],
            conjunct_relations: vec![],
            primary_keys: vec![PrimaryKey {
                name: "id".to_owned(),
            }],
        }
    }

    /// Renders every item with `ToString` so token streams can be compared by value.
    fn rendered<T: ToString>(items: impl IntoIterator<Item = T>) -> Vec<String> {
        items.into_iter().map(|item| item.to_string()).collect()
    }

    #[test]
    fn test_get_table_name_snake_case() {
        let entity = setup();

        assert_eq!(
            entity.get_table_name_snake_case(),
            "special_cake".to_owned()
        );
    }

    #[test]
    fn test_get_table_name_camel_case() {
        let entity = setup();

        assert_eq!(entity.get_table_name_camel_case(), "SpecialCake".to_owned());
    }

    #[test]
    fn test_get_table_name_snake_case_ident() {
        let entity = setup();

        assert_eq!(
            entity.get_table_name_snake_case_ident(),
            format_ident!("{}", "special_cake")
        );
    }

    #[test]
    fn test_get_table_name_camel_case_ident() {
        let entity = setup();

        assert_eq!(
            entity.get_table_name_camel_case_ident(),
            format_ident!("{}", "SpecialCake")
        );
    }

    // The collection tests below compare whole Vecs (not element-by-element loops),
    // so a getter that returned too few or no items would fail instead of passing
    // vacuously.

    #[test]
    fn test_get_column_names_snake_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .columns
            .iter()
            .map(|col| col.get_name_snake_case())
            .collect();

        assert_eq!(entity.get_column_names_snake_case(), expected);
    }

    #[test]
    fn test_get_column_names_camel_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .columns
            .iter()
            .map(|col| col.get_name_camel_case())
            .collect();

        assert_eq!(entity.get_column_names_camel_case(), expected);
    }

    #[test]
    fn test_get_column_rs_types() {
        let entity = setup();
        let expected = rendered(
            entity
                .columns
                .iter()
                .map(|col| col.get_rs_type(&DateTimeCrate::Chrono)),
        );

        assert_eq!(
            rendered(entity.get_column_rs_types(&DateTimeCrate::Chrono)),
            expected
        );
    }

    #[test]
    fn test_get_column_defs() {
        let entity = setup();
        let expected = rendered(entity.columns.iter().map(|col| col.get_def()));

        assert_eq!(rendered(entity.get_column_defs()), expected);
    }

    #[test]
    fn test_get_primary_key_names_snake_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .primary_keys
            .iter()
            .map(|pk| pk.get_name_snake_case())
            .collect();

        assert_eq!(entity.get_primary_key_names_snake_case(), expected);
    }

    #[test]
    fn test_get_primary_key_names_camel_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .primary_keys
            .iter()
            .map(|pk| pk.get_name_camel_case())
            .collect();

        assert_eq!(entity.get_primary_key_names_camel_case(), expected);
    }

    #[test]
    fn test_get_relation_module_name() {
        let entity = setup();
        let expected: Vec<_> = entity
            .relations
            .iter()
            .map(|rel| rel.get_module_name())
            .collect();

        assert_eq!(entity.get_relation_module_name(), expected);
    }

    #[test]
    fn test_get_relation_enum_name() {
        let entity = setup();
        let expected: Vec<_> = entity
            .relations
            .iter()
            .map(|rel| rel.get_enum_name())
            .collect();

        assert_eq!(entity.get_relation_enum_name(), expected);
    }

    #[test]
    fn test_get_relation_defs() {
        let entity = setup();
        let expected = rendered(entity.relations.iter().map(|rel| rel.get_def()));

        assert_eq!(rendered(entity.get_relation_defs()), expected);
    }

    #[test]
    fn test_get_relation_attrs() {
        let entity = setup();
        let expected = rendered(entity.relations.iter().map(|rel| rel.get_attrs()));

        assert_eq!(rendered(entity.get_relation_attrs()), expected);
    }

    #[test]
    fn test_get_related_entity_enum_name() {
        let mut entity = setup();
        let plain: Vec<_> = entity
            .relations
            .iter()
            .map(|rel| rel.get_enum_name())
            .collect();
        assert_eq!(entity.get_related_entity_enum_name(), plain);

        // A self-referencing relation adds a trailing `<Name>Reverse` entry.
        entity.relations[0].self_referencing = true;
        let mut expected: Vec<_> = entity
            .relations
            .iter()
            .map(|rel| rel.get_enum_name())
            .collect();
        expected.push(format_ident!("{}Reverse", entity.relations[0].get_enum_name()));
        assert_eq!(entity.get_related_entity_enum_name(), expected);
    }

    #[test]
    fn test_get_related_entity_modules() {
        let entity = setup();
        let expected: Vec<_> = entity
            .relations
            .iter()
            .map(|rel| {
                rel.get_module_name()
                    .unwrap_or_else(|| format_ident!("self"))
            })
            .collect();

        assert_eq!(entity.get_related_entity_modules(), expected);
    }

    #[test]
    fn test_related_entity_outputs_are_aligned() {
        let mut entity = setup();
        entity.relations[1].self_referencing = true;

        // Names, modules and attributes describe the same variants, one entry each.
        let names = entity.get_related_entity_enum_name();
        assert_eq!(entity.get_related_entity_modules().len(), names.len());
        assert_eq!(entity.get_related_entity_attrs().len(), names.len());
    }

    #[test]
    fn test_get_primary_key_auto_increment() {
        let mut entity = setup();

        assert_eq!(
            entity.get_primary_key_auto_increment(),
            format_ident!("{}", false)
        );

        entity.columns[0].auto_increment = true;
        assert_eq!(
            entity.get_primary_key_auto_increment(),
            format_ident!("{}", true)
        );
    }

    #[test]
    fn test_get_primary_key_rs_type() {
        let entity = setup();

        assert_eq!(
            entity
                .get_primary_key_rs_type(&DateTimeCrate::Chrono)
                .to_string(),
            entity.columns[0]
                .get_rs_type(&DateTimeCrate::Chrono)
                .to_string()
        );
    }

    #[test]
    fn test_get_primary_key_rs_type_composite() {
        let mut entity = setup();
        entity.primary_keys.push(PrimaryKey {
            name: "name".to_owned(),
        });

        let id_type = entity.columns[0]
            .get_rs_type(&DateTimeCrate::Chrono)
            .to_string();
        let name_type = entity.columns[1]
            .get_rs_type(&DateTimeCrate::Chrono)
            .to_string();
        let expected: proc_macro2::TokenStream =
            format!("({id_type}, {name_type})").parse().unwrap();

        assert_eq!(
            entity
                .get_primary_key_rs_type(&DateTimeCrate::Chrono)
                .to_string(),
            expected.to_string()
        );
    }

    #[test]
    fn test_get_primary_key_rs_type_without_keys() {
        let mut entity = setup();
        entity.primary_keys.clear();

        assert!(entity
            .get_primary_key_rs_type(&DateTimeCrate::Chrono)
            .is_empty());
    }

    #[test]
    fn test_get_conjunct_relations_via_snake_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .conjunct_relations
            .iter()
            .map(|rel| rel.get_via_snake_case())
            .collect();

        assert_eq!(entity.get_conjunct_relations_via_snake_case(), expected);
    }

    #[test]
    fn test_get_conjunct_relations_to_snake_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .conjunct_relations
            .iter()
            .map(|rel| rel.get_to_snake_case())
            .collect();

        assert_eq!(entity.get_conjunct_relations_to_snake_case(), expected);
    }

    #[test]
    fn test_get_conjunct_relations_to_upper_camel_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .conjunct_relations
            .iter()
            .map(|rel| rel.get_to_upper_camel_case())
            .collect();

        assert_eq!(entity.get_conjunct_relations_to_upper_camel_case(), expected);
    }

    #[test]
    fn test_get_eq_needed() {
        let entity = setup();
        let expected = quote! {, Eq};

        assert_eq!(entity.get_eq_needed().to_string(), expected.to_string());
    }

    #[test]
    fn test_get_eq_needed_with_float_column() {
        let mut entity = setup();
        entity.columns[1].col_type = ColumnType::Float;

        assert_eq!(entity.get_eq_needed().to_string(), quote! {}.to_string());
    }
}