sea_orm_codegen/entity/
base_entity.rs

1use heck::{ToSnakeCase, ToUpperCamelCase};
2use proc_macro2::{Ident, TokenStream};
3use quote::format_ident;
4use quote::quote;
5use sea_query::ColumnType;
6
7use crate::{
8    Column, ConjunctRelation, DateTimeCrate, PrimaryKey, Relation, util::escape_rust_keyword,
9};
10
/// Description of one database table, used by the `impl` below to produce the
/// identifiers and token streams of the generated entity code.
#[derive(Clone, Debug)]
pub struct Entity {
    /// Table name as given; the `get_table_name_*` helpers convert it to
    /// `snake_case` / `UpperCamelCase`.
    pub(crate) table_name: String,
    /// Columns of the table; the column getters preserve this order.
    pub(crate) columns: Vec<Column>,
    /// Relations of this table; each one contributes a `Relation` enum variant name.
    pub(crate) relations: Vec<Relation>,
    /// Relations that presumably reach a target entity (`to`) through an intermediate
    /// one (`via`) — NOTE(review): confirm against `ConjunctRelation`.
    pub(crate) conjunct_relations: Vec<ConjunctRelation>,
    /// Primary-key entries; matched against `Column::name` by name.
    pub(crate) primary_keys: Vec<PrimaryKey>,
}
19
20impl Entity {
21    pub fn get_table_name_snake_case(&self) -> String {
22        self.table_name.to_snake_case()
23    }
24
25    pub fn get_table_name_camel_case(&self) -> String {
26        self.table_name.to_upper_camel_case()
27    }
28
29    pub fn get_table_name_snake_case_ident(&self) -> Ident {
30        format_ident!("{}", escape_rust_keyword(self.get_table_name_snake_case()))
31    }
32
33    pub fn get_table_name_camel_case_ident(&self) -> Ident {
34        format_ident!("{}", escape_rust_keyword(self.get_table_name_camel_case()))
35    }
36
37    pub fn get_column_names_snake_case(&self) -> Vec<Ident> {
38        self.columns
39            .iter()
40            .map(|col| col.get_name_snake_case())
41            .collect()
42    }
43
44    pub fn get_column_names_camel_case(&self) -> Vec<Ident> {
45        self.columns
46            .iter()
47            .map(|col| col.get_name_camel_case())
48            .collect()
49    }
50
51    pub fn get_column_rs_types(&self, date_time_crate: &DateTimeCrate) -> Vec<TokenStream> {
52        self.columns
53            .clone()
54            .into_iter()
55            .map(|col| col.get_rs_type(date_time_crate))
56            .collect()
57    }
58
59    pub fn get_column_defs(&self) -> Vec<TokenStream> {
60        self.columns
61            .clone()
62            .into_iter()
63            .map(|col| col.get_def())
64            .collect()
65    }
66
67    pub fn get_primary_key_names_snake_case(&self) -> Vec<Ident> {
68        self.primary_keys
69            .iter()
70            .map(|pk| pk.get_name_snake_case())
71            .collect()
72    }
73
74    pub fn get_primary_key_names_camel_case(&self) -> Vec<Ident> {
75        self.primary_keys
76            .iter()
77            .map(|pk| pk.get_name_camel_case())
78            .collect()
79    }
80
81    pub fn get_relation_module_name(&self) -> Vec<Option<Ident>> {
82        self.relations
83            .iter()
84            .map(|rel| rel.get_module_name())
85            .collect()
86    }
87
88    pub fn get_relation_enum_name(&self) -> Vec<Ident> {
89        self.relations
90            .iter()
91            .map(|rel| rel.get_enum_name())
92            .collect()
93    }
94
95    /// Used to generate the names for the `enum RelatedEntity` that is useful to the Seaography project
96    pub fn get_related_entity_enum_name(&self) -> Vec<Ident> {
97        // 1st step get conjunct relations data
98        let conjunct_related_names = self.get_conjunct_relations_to_upper_camel_case();
99
100        // 2nd step get reverse self relations data
101        let self_relations_reverse = self
102            .relations
103            .iter()
104            .filter(|rel| rel.self_referencing)
105            .map(|rel| format_ident!("{}Reverse", rel.get_enum_name()));
106
107        // 3rd step get normal relations data
108        self.get_relation_enum_name()
109            .into_iter()
110            .chain(self_relations_reverse)
111            .chain(conjunct_related_names)
112            .collect()
113    }
114
115    pub fn get_relation_defs(&self) -> Vec<TokenStream> {
116        self.relations.iter().map(|rel| rel.get_def()).collect()
117    }
118
119    pub fn get_relation_attrs(&self) -> Vec<TokenStream> {
120        self.relations.iter().map(|rel| rel.get_attrs()).collect()
121    }
122
123    /// Used to generate the attributes for the `enum RelatedEntity` that is useful to the Seaography project
124    pub fn get_related_entity_attrs(&self) -> Vec<TokenStream> {
125        // 1st step get conjunct relations data
126        let conjunct_related_attrs = self.conjunct_relations.iter().map(|conj| {
127            let entity = format!("super::{}::Entity", conj.get_to_snake_case());
128
129            quote! {
130                #[sea_orm(
131                    entity = #entity
132                )]
133            }
134        });
135
136        // helper function that generates attributes for `Relation` data
137        let produce_relation_attrs = |rel: &Relation, reverse: bool| {
138            let entity = match rel.get_module_name() {
139                Some(module_name) => format!("super::{}::Entity", module_name),
140                None => String::from("Entity"),
141            };
142
143            if rel.self_referencing || !rel.impl_related || rel.num_suffix > 0 {
144                let def = if reverse {
145                    format!("Relation::{}.def().rev()", rel.get_enum_name())
146                } else {
147                    format!("Relation::{}.def()", rel.get_enum_name())
148                };
149
150                quote! {
151                    #[sea_orm(
152                        entity = #entity,
153                        def = #def
154                    )]
155                }
156            } else {
157                quote! {
158                    #[sea_orm(
159                        entity = #entity
160                    )]
161                }
162            }
163        };
164
165        // 2nd step get reverse self relations data
166        let self_relations_reverse_attrs = self
167            .relations
168            .iter()
169            .filter(|rel| rel.self_referencing)
170            .map(|rel| produce_relation_attrs(rel, true));
171
172        // 3rd step get normal relations data
173        self.relations
174            .iter()
175            .map(|rel| produce_relation_attrs(rel, false))
176            .chain(self_relations_reverse_attrs)
177            .chain(conjunct_related_attrs)
178            .collect()
179    }
180
181    pub fn get_primary_key_auto_increment(&self) -> Ident {
182        let auto_increment = self.columns.iter().any(|col| col.auto_increment);
183        format_ident!("{}", auto_increment)
184    }
185
186    pub fn get_primary_key_rs_type(&self, date_time_crate: &DateTimeCrate) -> TokenStream {
187        let types = self
188            .primary_keys
189            .iter()
190            .map(|primary_key| {
191                self.columns
192                    .iter()
193                    .find(|col| col.name.eq(&primary_key.name))
194                    .unwrap()
195                    .get_rs_type(date_time_crate)
196                    .to_string()
197            })
198            .collect::<Vec<_>>();
199        if !types.is_empty() {
200            let value_type = if types.len() > 1 {
201                vec!["(".to_owned(), types.join(", "), ")".to_owned()]
202            } else {
203                types
204            };
205            value_type.join("").parse().unwrap()
206        } else {
207            TokenStream::new()
208        }
209    }
210
211    pub fn get_conjunct_relations_via_snake_case(&self) -> Vec<Ident> {
212        self.conjunct_relations
213            .iter()
214            .map(|con_rel| con_rel.get_via_snake_case())
215            .collect()
216    }
217
218    pub fn get_conjunct_relations_to_snake_case(&self) -> Vec<Ident> {
219        self.conjunct_relations
220            .iter()
221            .map(|con_rel| con_rel.get_to_snake_case())
222            .collect()
223    }
224
225    pub fn get_conjunct_relations_to_upper_camel_case(&self) -> Vec<Ident> {
226        self.conjunct_relations
227            .iter()
228            .map(|con_rel| con_rel.get_to_upper_camel_case())
229            .collect()
230    }
231
232    pub fn get_eq_needed(&self) -> TokenStream {
233        fn is_floats(col_type: &ColumnType) -> bool {
234            match col_type {
235                ColumnType::Float | ColumnType::Double => true,
236                ColumnType::Array(col_type) => is_floats(col_type),
237                ColumnType::Vector(_) => true,
238                _ => false,
239            }
240        }
241        self.columns
242            .iter()
243            .find(|column| is_floats(&column.col_type))
244            // check if float or double exist.
245            // if exist, return nothing
246            .map_or(quote! {, Eq}, |_| quote! {})
247    }
248
249    pub fn get_column_serde_attributes(
250        &self,
251        serde_skip_deserializing_primary_key: bool,
252        serde_skip_hidden_column: bool,
253    ) -> Vec<TokenStream> {
254        self.columns
255            .iter()
256            .map(|col| {
257                let is_primary_key = self.primary_keys.iter().any(|pk| pk.name == col.name);
258                col.get_serde_attribute(
259                    is_primary_key,
260                    serde_skip_deserializing_primary_key,
261                    serde_skip_hidden_column,
262                )
263            })
264            .collect()
265    }
266}
267
#[cfg(test)]
mod tests {
    use quote::{format_ident, quote};
    use sea_query::{ColumnType, ForeignKeyAction, StringLen};

    use crate::{Column, DateTimeCrate, Entity, PrimaryKey, Relation, RelationType};

    /// Builds a `special_cake` entity with two nullable columns (`id`, `name`), two
    /// `HasOne` relations (`fruit`, `filling`), no conjunct relations, and `id` as its
    /// only primary key.
    fn setup() -> Entity {
        Entity {
            table_name: "special_cake".to_owned(),
            columns: vec![
                Column {
                    name: "id".to_owned(),
                    col_type: ColumnType::Integer,
                    auto_increment: false,
                    not_null: false,
                    unique: false,
                },
                Column {
                    name: "name".to_owned(),
                    col_type: ColumnType::String(StringLen::None),
                    auto_increment: false,
                    not_null: false,
                    unique: false,
                },
            ],
            relations: vec![
                Relation {
                    ref_table: "fruit".to_owned(),
                    columns: vec!["id".to_owned()],
                    ref_columns: vec!["cake_id".to_owned()],
                    rel_type: RelationType::HasOne,
                    on_delete: Some(ForeignKeyAction::Cascade),
                    on_update: Some(ForeignKeyAction::Cascade),
                    self_referencing: false,
                    num_suffix: 0,
                    impl_related: true,
                },
                Relation {
                    ref_table: "filling".to_owned(),
                    columns: vec!["id".to_owned()],
                    ref_columns: vec!["cake_id".to_owned()],
                    rel_type: RelationType::HasOne,
                    on_delete: Some(ForeignKeyAction::Cascade),
                    on_update: Some(ForeignKeyAction::Cascade),
                    self_referencing: false,
                    num_suffix: 0,
                    impl_related: true,
                },
            ],
            conjunct_relations: vec![],
            primary_keys: vec![PrimaryKey {
                name: "id".to_owned(),
            }],
        }
    }

    /// Renders every item with `to_string`, so token streams can be compared by value.
    fn to_strings<T: ToString>(items: impl IntoIterator<Item = T>) -> Vec<String> {
        items.into_iter().map(|item| item.to_string()).collect()
    }

    // The list getters below are checked by comparing whole vectors rather than by
    // indexing in a loop, so a missing or an extra element also fails the test.

    #[test]
    fn test_get_table_name_snake_case() {
        let entity = setup();

        assert_eq!(
            entity.get_table_name_snake_case(),
            "special_cake".to_owned()
        );
    }

    #[test]
    fn test_get_table_name_camel_case() {
        let entity = setup();

        assert_eq!(entity.get_table_name_camel_case(), "SpecialCake".to_owned());
    }

    #[test]
    fn test_get_table_name_snake_case_ident() {
        let entity = setup();

        assert_eq!(
            entity.get_table_name_snake_case_ident(),
            format_ident!("{}", "special_cake")
        );
    }

    #[test]
    fn test_get_table_name_camel_case_ident() {
        let entity = setup();

        assert_eq!(
            entity.get_table_name_camel_case_ident(),
            format_ident!("{}", "SpecialCake")
        );
    }

    #[test]
    fn test_get_column_names_snake_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .columns
            .iter()
            .map(|col| col.get_name_snake_case())
            .collect();

        assert_eq!(entity.get_column_names_snake_case(), expected);
    }

    #[test]
    fn test_get_column_names_camel_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .columns
            .iter()
            .map(|col| col.get_name_camel_case())
            .collect();

        assert_eq!(entity.get_column_names_camel_case(), expected);
    }

    #[test]
    fn test_get_column_rs_types() {
        let entity = setup();
        let expected = to_strings(
            entity
                .columns
                .iter()
                .map(|col| col.get_rs_type(&DateTimeCrate::Chrono)),
        );

        assert_eq!(
            to_strings(entity.get_column_rs_types(&DateTimeCrate::Chrono)),
            expected
        );
    }

    #[test]
    fn test_get_column_defs() {
        let entity = setup();
        let expected = to_strings(entity.columns.iter().map(|col| col.get_def()));

        assert_eq!(to_strings(entity.get_column_defs()), expected);
    }

    #[test]
    fn test_get_primary_key_names_snake_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .primary_keys
            .iter()
            .map(|pk| pk.get_name_snake_case())
            .collect();

        assert_eq!(entity.get_primary_key_names_snake_case(), expected);
    }

    #[test]
    fn test_get_primary_key_names_camel_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .primary_keys
            .iter()
            .map(|pk| pk.get_name_camel_case())
            .collect();

        assert_eq!(entity.get_primary_key_names_camel_case(), expected);
    }

    #[test]
    fn test_get_relation_module_name() {
        let entity = setup();
        let expected: Vec<_> = entity
            .relations
            .iter()
            .map(|rel| rel.get_module_name())
            .collect();

        assert_eq!(entity.get_relation_module_name(), expected);
    }

    #[test]
    fn test_get_relation_enum_name() {
        let entity = setup();
        let expected: Vec<_> = entity
            .relations
            .iter()
            .map(|rel| rel.get_enum_name())
            .collect();

        assert_eq!(entity.get_relation_enum_name(), expected);
    }

    #[test]
    fn test_get_related_entity_enum_name() {
        let entity = setup();

        // With no self-referencing and no conjunct relations, only the plain relation
        // names remain.
        assert_eq!(
            entity.get_related_entity_enum_name(),
            entity.get_relation_enum_name()
        );
    }

    #[test]
    fn test_get_relation_defs() {
        let entity = setup();
        let expected = to_strings(entity.relations.iter().map(|rel| rel.get_def()));

        assert_eq!(to_strings(entity.get_relation_defs()), expected);
    }

    #[test]
    fn test_get_relation_attrs() {
        let entity = setup();
        let expected = to_strings(entity.relations.iter().map(|rel| rel.get_attrs()));

        assert_eq!(to_strings(entity.get_relation_attrs()), expected);
    }

    #[test]
    fn test_get_primary_key_auto_increment() {
        let mut entity = setup();

        assert_eq!(
            entity.get_primary_key_auto_increment(),
            format_ident!("{}", false)
        );

        entity.columns[0].auto_increment = true;
        assert_eq!(
            entity.get_primary_key_auto_increment(),
            format_ident!("{}", true)
        );
    }

    #[test]
    fn test_get_primary_key_rs_type() {
        let entity = setup();

        assert_eq!(
            entity
                .get_primary_key_rs_type(&DateTimeCrate::Chrono)
                .to_string(),
            entity.columns[0]
                .get_rs_type(&DateTimeCrate::Chrono)
                .to_string()
        );
    }

    // `setup()` has no conjunct relations, so the three conjunct tests below only
    // cover the empty case.

    #[test]
    fn test_get_conjunct_relations_via_snake_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .conjunct_relations
            .iter()
            .map(|con_rel| con_rel.get_via_snake_case())
            .collect();

        assert_eq!(entity.get_conjunct_relations_via_snake_case(), expected);
    }

    #[test]
    fn test_get_conjunct_relations_to_snake_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .conjunct_relations
            .iter()
            .map(|con_rel| con_rel.get_to_snake_case())
            .collect();

        assert_eq!(entity.get_conjunct_relations_to_snake_case(), expected);
    }

    #[test]
    fn test_get_conjunct_relations_to_upper_camel_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .conjunct_relations
            .iter()
            .map(|con_rel| con_rel.get_to_upper_camel_case())
            .collect();

        assert_eq!(entity.get_conjunct_relations_to_upper_camel_case(), expected);
    }

    #[test]
    fn test_get_eq_needed() {
        let entity = setup();
        let expected = quote! {, Eq};

        assert_eq!(entity.get_eq_needed().to_string(), expected.to_string());
    }

    #[test]
    fn test_get_eq_needed_with_float_column() {
        let mut entity = setup();
        entity.columns[1].col_type = ColumnType::Double;

        // A floating-point column rules out deriving `Eq`, so nothing is emitted.
        assert!(entity.get_eq_needed().is_empty());
    }
}