sea_orm_codegen/entity/base_entity.rs

1use heck::{ToSnakeCase, ToUpperCamelCase};
2use proc_macro2::{Ident, TokenStream};
3use quote::format_ident;
4use quote::quote;
5use sea_query::ColumnType;
6
7use crate::{
8    Column, ColumnOption, ConjunctRelation, PrimaryKey, Relation, util::escape_rust_keyword,
9};
10
11#[derive(Clone, Debug)]
12pub struct Entity {
13    pub(crate) table_name: String,
14    pub(crate) columns: Vec<Column>,
15    pub(crate) relations: Vec<Relation>,
16    pub(crate) conjunct_relations: Vec<ConjunctRelation>,
17    pub(crate) primary_keys: Vec<PrimaryKey>,
18}
19
20impl Entity {
21    pub fn get_table_name_snake_case(&self) -> String {
22        self.table_name.to_snake_case()
23    }
24
25    pub fn get_table_name_camel_case(&self) -> String {
26        self.table_name.to_upper_camel_case()
27    }
28
29    pub fn get_table_name_snake_case_ident(&self) -> Ident {
30        format_ident!("{}", escape_rust_keyword(self.get_table_name_snake_case()))
31    }
32
33    pub fn get_table_name_camel_case_ident(&self) -> Ident {
34        format_ident!("{}", escape_rust_keyword(self.get_table_name_camel_case()))
35    }
36
37    pub fn get_column_names_snake_case(&self) -> Vec<Ident> {
38        self.columns
39            .iter()
40            .map(|col| col.get_name_snake_case())
41            .collect()
42    }
43
44    pub fn get_column_names_camel_case(&self) -> Vec<Ident> {
45        self.columns
46            .iter()
47            .map(|col| col.get_name_camel_case())
48            .collect()
49    }
50
51    pub fn get_column_rs_types(&self, opt: &ColumnOption) -> Vec<TokenStream> {
52        self.columns
53            .clone()
54            .into_iter()
55            .map(|col| col.get_rs_type(opt))
56            .collect()
57    }
58
59    pub fn get_column_defs(&self) -> Vec<TokenStream> {
60        self.columns
61            .clone()
62            .into_iter()
63            .map(|col| col.get_def())
64            .collect()
65    }
66
67    pub fn get_primary_key_names_snake_case(&self) -> Vec<Ident> {
68        self.primary_keys
69            .iter()
70            .map(|pk| pk.get_name_snake_case())
71            .collect()
72    }
73
74    pub fn get_primary_key_names_camel_case(&self) -> Vec<Ident> {
75        self.primary_keys
76            .iter()
77            .map(|pk| pk.get_name_camel_case())
78            .collect()
79    }
80
81    pub fn get_relation_module_name(&self) -> Vec<Option<Ident>> {
82        self.relations
83            .iter()
84            .map(|rel| rel.get_module_name())
85            .collect()
86    }
87
88    pub fn get_relation_enum_name(&self) -> Vec<Ident> {
89        self.relations
90            .iter()
91            .map(|rel| rel.get_enum_name())
92            .collect()
93    }
94
95    /// Used to generate the names for the `enum RelatedEntity` that is useful to the Seaography project
96    pub fn get_related_entity_enum_name(&self) -> Vec<Ident> {
97        // 1st step get conjunct relations data
98        let conjunct_related_names = self.get_conjunct_relations_to_upper_camel_case();
99
100        // 2nd step get reverse self relations data
101        let self_relations_reverse = self
102            .relations
103            .iter()
104            .filter(|rel| rel.self_referencing)
105            .map(|rel| format_ident!("{}Reverse", rel.get_enum_name()));
106
107        // 3rd step get normal relations data
108        self.get_relation_enum_name()
109            .into_iter()
110            .chain(self_relations_reverse)
111            .chain(conjunct_related_names)
112            .collect()
113    }
114
115    pub fn get_relation_defs(&self) -> Vec<TokenStream> {
116        self.relations.iter().map(|rel| rel.get_def()).collect()
117    }
118
119    pub fn get_relation_attrs(&self) -> Vec<TokenStream> {
120        self.relations.iter().map(|rel| rel.get_attrs()).collect()
121    }
122
123    /// Trimmed get_related_entity_attrs down to just the entity module
124    pub fn get_related_entity_modules(&self) -> Vec<Ident> {
125        // 1st step get conjunct relations data
126        let conjunct_related_attrs = self
127            .conjunct_relations
128            .iter()
129            .map(|conj| conj.get_to_snake_case());
130
131        // helper function that generates attributes for `Relation` data
132        let produce_relation_attrs = |rel: &Relation, _reverse: bool| match rel.get_module_name() {
133            Some(module_name) => module_name,
134            None => format_ident!("self"),
135        };
136
137        // 2nd step get reverse self relations data
138        let self_relations_reverse_attrs = self
139            .relations
140            .iter()
141            .filter(|rel| rel.self_referencing)
142            .map(|rel| produce_relation_attrs(rel, true));
143
144        // 3rd step get normal relations data
145        self.relations
146            .iter()
147            .map(|rel| produce_relation_attrs(rel, false))
148            .chain(self_relations_reverse_attrs)
149            .chain(conjunct_related_attrs)
150            .collect()
151    }
152
153    /// Used to generate the attributes for the `enum RelatedEntity` that is useful to the Seaography project
154    pub fn get_related_entity_attrs(&self) -> Vec<TokenStream> {
155        // 1st step get conjunct relations data
156        let conjunct_related_attrs = self.conjunct_relations.iter().map(|conj| {
157            let entity = format!("super::{}::Entity", conj.get_to_snake_case());
158
159            quote! {
160                #[sea_orm(
161                    entity = #entity
162                )]
163            }
164        });
165
166        // helper function that generates attributes for `Relation` data
167        let produce_relation_attrs = |rel: &Relation, reverse: bool| {
168            let entity = match rel.get_module_name() {
169                Some(module_name) => format!("super::{module_name}::Entity"),
170                None => String::from("Entity"),
171            };
172
173            if rel.self_referencing || !rel.impl_related || rel.num_suffix > 0 {
174                let def = if reverse {
175                    format!("Relation::{}.def().rev()", rel.get_enum_name())
176                } else {
177                    format!("Relation::{}.def()", rel.get_enum_name())
178                };
179
180                quote! {
181                    #[sea_orm(
182                        entity = #entity,
183                        def = #def
184                    )]
185                }
186            } else {
187                quote! {
188                    #[sea_orm(
189                        entity = #entity
190                    )]
191                }
192            }
193        };
194
195        // 2nd step get reverse self relations data
196        let self_relations_reverse_attrs = self
197            .relations
198            .iter()
199            .filter(|rel| rel.self_referencing)
200            .map(|rel| produce_relation_attrs(rel, true));
201
202        // 3rd step get normal relations data
203        self.relations
204            .iter()
205            .map(|rel| produce_relation_attrs(rel, false))
206            .chain(self_relations_reverse_attrs)
207            .chain(conjunct_related_attrs)
208            .collect()
209    }
210
211    pub fn get_primary_key_auto_increment(&self) -> Ident {
212        let auto_increment = self.columns.iter().any(|col| col.auto_increment);
213        format_ident!("{}", auto_increment)
214    }
215
216    pub fn get_primary_key_rs_type(&self, opt: &ColumnOption) -> TokenStream {
217        let types = self
218            .primary_keys
219            .iter()
220            .map(|primary_key| {
221                self.columns
222                    .iter()
223                    .find(|col| col.name.eq(&primary_key.name))
224                    .unwrap()
225                    .get_rs_type(opt)
226                    .to_string()
227            })
228            .collect::<Vec<_>>();
229        if !types.is_empty() {
230            let value_type = if types.len() > 1 {
231                vec!["(".to_owned(), types.join(", "), ")".to_owned()]
232            } else {
233                types
234            };
235            value_type.join("").parse().unwrap()
236        } else {
237            TokenStream::new()
238        }
239    }
240
241    pub fn get_conjunct_relations_via_snake_case(&self) -> Vec<Ident> {
242        self.conjunct_relations
243            .iter()
244            .map(|con_rel| con_rel.get_via_snake_case())
245            .collect()
246    }
247
248    pub fn get_conjunct_relations_to_snake_case(&self) -> Vec<Ident> {
249        self.conjunct_relations
250            .iter()
251            .map(|con_rel| con_rel.get_to_snake_case())
252            .collect()
253    }
254
255    pub fn get_conjunct_relations_to_upper_camel_case(&self) -> Vec<Ident> {
256        self.conjunct_relations
257            .iter()
258            .map(|con_rel| con_rel.get_to_upper_camel_case())
259            .collect()
260    }
261
262    pub fn get_eq_needed(&self) -> TokenStream {
263        fn is_floats(col_type: &ColumnType) -> bool {
264            match col_type {
265                ColumnType::Float | ColumnType::Double => true,
266                ColumnType::Array(col_type) => is_floats(col_type),
267                ColumnType::Vector(_) => true,
268                _ => false,
269            }
270        }
271        self.columns
272            .iter()
273            .find(|column| is_floats(&column.col_type))
274            // check if float or double exist.
275            // if exist, return nothing
276            .map_or(quote! {, Eq}, |_| quote! {})
277    }
278
279    pub fn get_column_serde_attributes(
280        &self,
281        serde_skip_deserializing_primary_key: bool,
282        serde_skip_hidden_column: bool,
283    ) -> Vec<TokenStream> {
284        self.columns
285            .iter()
286            .map(|col| {
287                let is_primary_key = self.primary_keys.iter().any(|pk| pk.name == col.name);
288                col.get_serde_attribute(
289                    is_primary_key,
290                    serde_skip_deserializing_primary_key,
291                    serde_skip_hidden_column,
292                )
293            })
294            .collect()
295    }
296}
297
#[cfg(test)]
mod tests {
    use proc_macro2::TokenStream;
    use quote::{format_ident, quote};
    use sea_query::{ColumnType, ForeignKeyAction, StringLen};

    use crate::{Column, ColumnOption, Entity, PrimaryKey, Relation, RelationType};

    /// Builds a `special_cake` entity with:
    /// - `id` and `name` columns;
    /// - two `HasOne` relations (`fruit`, `filling`);
    /// - no conjunct relations;
    /// - `id` as its only primary key.
    fn setup() -> Entity {
        Entity {
            table_name: "special_cake".to_owned(),
            columns: vec![
                Column {
                    name: "id".to_owned(),
                    col_type: ColumnType::Integer,
                    auto_increment: false,
                    not_null: false,
                    unique: false,
                    unique_key: None,
                },
                Column {
                    name: "name".to_owned(),
                    col_type: ColumnType::String(StringLen::None),
                    auto_increment: false,
                    not_null: false,
                    unique: false,
                    unique_key: None,
                },
            ],
            relations: vec![
                Relation {
                    ref_table: "fruit".to_owned(),
                    columns: vec!["id".to_owned()],
                    ref_columns: vec!["cake_id".to_owned()],
                    rel_type: RelationType::HasOne,
                    on_delete: Some(ForeignKeyAction::Cascade),
                    on_update: Some(ForeignKeyAction::Cascade),
                    self_referencing: false,
                    num_suffix: 0,
                    impl_related: true,
                },
                Relation {
                    ref_table: "filling".to_owned(),
                    columns: vec!["id".to_owned()],
                    ref_columns: vec!["cake_id".to_owned()],
                    rel_type: RelationType::HasOne,
                    on_delete: Some(ForeignKeyAction::Cascade),
                    on_update: Some(ForeignKeyAction::Cascade),
                    self_referencing: false,
                    num_suffix: 0,
                    impl_related: true,
                },
            ],
            conjunct_relations: vec![],
            primary_keys: vec![PrimaryKey {
                name: "id".to_owned(),
            }],
        }
    }

    /// Renders every item with `to_string`, so token streams can be compared as
    /// text. Comparing whole `Vec`s also checks that the lengths agree.
    fn render<T: ToString>(items: impl IntoIterator<Item = T>) -> Vec<String> {
        items.into_iter().map(|item| item.to_string()).collect()
    }

    #[test]
    fn test_get_table_name_snake_case() {
        let entity = setup();

        assert_eq!(
            entity.get_table_name_snake_case(),
            "special_cake".to_owned()
        );
    }

    #[test]
    fn test_get_table_name_camel_case() {
        let entity = setup();

        assert_eq!(entity.get_table_name_camel_case(), "SpecialCake".to_owned());
    }

    #[test]
    fn test_get_table_name_snake_case_ident() {
        let entity = setup();

        assert_eq!(
            entity.get_table_name_snake_case_ident(),
            format_ident!("{}", "special_cake")
        );
    }

    #[test]
    fn test_get_table_name_camel_case_ident() {
        let entity = setup();

        assert_eq!(
            entity.get_table_name_camel_case_ident(),
            format_ident!("{}", "SpecialCake")
        );
    }

    #[test]
    fn test_get_column_names_snake_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .columns
            .iter()
            .map(|col| col.get_name_snake_case())
            .collect();

        assert_eq!(entity.get_column_names_snake_case(), expected);
    }

    #[test]
    fn test_get_column_names_camel_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .columns
            .iter()
            .map(|col| col.get_name_camel_case())
            .collect();

        assert_eq!(entity.get_column_names_camel_case(), expected);
    }

    #[test]
    fn test_get_column_rs_types() {
        let entity = setup();
        let opt = ColumnOption::default();
        let expected = render(entity.columns.iter().map(|col| col.get_rs_type(&opt)));

        assert_eq!(render(entity.get_column_rs_types(&opt)), expected);
    }

    #[test]
    fn test_get_column_defs() {
        let entity = setup();
        let expected = render(entity.columns.iter().map(|col| col.get_def()));

        assert_eq!(render(entity.get_column_defs()), expected);
    }

    #[test]
    fn test_get_primary_key_names_snake_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .primary_keys
            .iter()
            .map(|pk| pk.get_name_snake_case())
            .collect();

        assert_eq!(entity.get_primary_key_names_snake_case(), expected);
    }

    #[test]
    fn test_get_primary_key_names_camel_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .primary_keys
            .iter()
            .map(|pk| pk.get_name_camel_case())
            .collect();

        assert_eq!(entity.get_primary_key_names_camel_case(), expected);
    }

    #[test]
    fn test_get_relation_module_name() {
        let entity = setup();
        let expected: Vec<_> = entity
            .relations
            .iter()
            .map(|rel| rel.get_module_name())
            .collect();

        assert_eq!(entity.get_relation_module_name(), expected);
    }

    #[test]
    fn test_get_relation_enum_name() {
        let entity = setup();
        let expected: Vec<_> = entity
            .relations
            .iter()
            .map(|rel| rel.get_enum_name())
            .collect();

        assert_eq!(entity.get_relation_enum_name(), expected);
    }

    #[test]
    fn test_get_related_entity_enum_name() {
        let entity = setup();

        // No self-referencing or conjunct relations: only the direct names remain.
        assert_eq!(
            entity.get_related_entity_enum_name(),
            entity.get_relation_enum_name()
        );
    }

    #[test]
    fn test_get_related_entity_modules() {
        let entity = setup();
        let expected: Vec<_> = entity
            .relations
            .iter()
            .map(|rel| {
                rel.get_module_name()
                    .unwrap_or_else(|| format_ident!("self"))
            })
            .collect();

        assert_eq!(entity.get_related_entity_modules(), expected);
    }

    #[test]
    fn test_get_relation_defs() {
        let entity = setup();
        let expected = render(entity.relations.iter().map(|rel| rel.get_def()));

        assert_eq!(render(entity.get_relation_defs()), expected);
    }

    #[test]
    fn test_get_relation_attrs() {
        let entity = setup();
        let expected = render(entity.relations.iter().map(|rel| rel.get_attrs()));

        assert_eq!(render(entity.get_relation_attrs()), expected);
    }

    #[test]
    fn test_get_primary_key_auto_increment() {
        let mut entity = setup();

        assert_eq!(
            entity.get_primary_key_auto_increment(),
            format_ident!("{}", false)
        );

        entity.columns[0].auto_increment = true;
        assert_eq!(
            entity.get_primary_key_auto_increment(),
            format_ident!("{}", true)
        );
    }

    #[test]
    fn test_get_primary_key_rs_type() {
        let entity = setup();
        let opt = Default::default();

        assert_eq!(
            entity.get_primary_key_rs_type(&opt).to_string(),
            entity.columns[0].get_rs_type(&opt).to_string()
        );
    }

    #[test]
    fn test_get_primary_key_rs_type_composite() {
        let mut entity = setup();
        entity.primary_keys.push(PrimaryKey {
            name: "name".to_owned(),
        });
        let opt = ColumnOption::default();
        let expected: TokenStream = format!(
            "({}, {})",
            entity.columns[0].get_rs_type(&opt),
            entity.columns[1].get_rs_type(&opt)
        )
        .parse()
        .unwrap();

        assert_eq!(
            entity.get_primary_key_rs_type(&opt).to_string(),
            expected.to_string()
        );
    }

    #[test]
    fn test_get_primary_key_rs_type_without_primary_key() {
        let mut entity = setup();
        entity.primary_keys.clear();
        let opt = ColumnOption::default();

        assert_eq!(entity.get_primary_key_rs_type(&opt).to_string(), "");
    }

    #[test]
    fn test_get_conjunct_relations_via_snake_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .conjunct_relations
            .iter()
            .map(|conj| conj.get_via_snake_case())
            .collect();

        assert_eq!(entity.get_conjunct_relations_via_snake_case(), expected);
    }

    #[test]
    fn test_get_conjunct_relations_to_snake_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .conjunct_relations
            .iter()
            .map(|conj| conj.get_to_snake_case())
            .collect();

        assert_eq!(entity.get_conjunct_relations_to_snake_case(), expected);
    }

    #[test]
    fn test_get_conjunct_relations_to_upper_camel_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .conjunct_relations
            .iter()
            .map(|conj| conj.get_to_upper_camel_case())
            .collect();

        assert_eq!(entity.get_conjunct_relations_to_upper_camel_case(), expected);
    }

    #[test]
    fn test_get_eq_needed() {
        let entity = setup();
        let expected = quote! {, Eq};

        assert_eq!(entity.get_eq_needed().to_string(), expected.to_string());
    }

    #[test]
    fn test_get_eq_needed_with_float_column() {
        let mut entity = setup();
        entity.columns[1].col_type = ColumnType::Double;

        assert_eq!(entity.get_eq_needed().to_string(), quote! {}.to_string());
    }
}