sea_orm_codegen/entity/relation.rs

1use heck::{ToSnakeCase, ToUpperCamelCase};
2use proc_macro2::{Ident, TokenStream};
3use quote::{format_ident, quote};
4use sea_query::{ForeignKeyAction, TableForeignKey};
5use syn::{punctuated::Punctuated, token::Comma};
6
7use crate::util::escape_rust_keyword;
8
/// The kind of relationship a [`Relation`] describes.
///
/// Selects both the `Entity::has_one` / `has_many` / `belongs_to` builder used in
/// the generated relation definition and the matching `#[sea_orm(...)]` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RelationType {
    /// Emitted as `has_one`.
    HasOne,
    /// Emitted as `has_many`.
    HasMany,
    /// Emitted as `belongs_to`; the only kind that emits `from`/`to` columns and
    /// foreign-key actions.
    BelongsTo,
}
15
16#[derive(Clone, Debug)]
17pub struct Relation {
18    pub(crate) ref_table: String,
19    pub(crate) columns: Vec<String>,
20    pub(crate) ref_columns: Vec<String>,
21    pub(crate) rel_type: RelationType,
22    pub(crate) on_update: Option<ForeignKeyAction>,
23    pub(crate) on_delete: Option<ForeignKeyAction>,
24    pub(crate) self_referencing: bool,
25    pub(crate) num_suffix: usize,
26    pub(crate) impl_related: bool,
27}
28
29impl Relation {
30    pub fn get_enum_name(&self) -> Ident {
31        let name = if self.self_referencing {
32            format_ident!("SelfRef")
33        } else {
34            format_ident!("{}", self.ref_table.to_upper_camel_case())
35        };
36        if self.num_suffix > 0 {
37            format_ident!("{}{}", name, self.num_suffix)
38        } else {
39            name
40        }
41    }
42
43    pub fn get_module_name(&self) -> Option<Ident> {
44        if self.self_referencing {
45            None
46        } else {
47            Some(format_ident!(
48                "{}",
49                escape_rust_keyword(self.ref_table.to_snake_case())
50            ))
51        }
52    }
53
54    pub fn get_def(&self) -> TokenStream {
55        let rel_type = self.get_rel_type();
56        let module_name = self.get_module_name();
57        let ref_entity = if module_name.is_some() {
58            quote! { super::#module_name::Entity }
59        } else {
60            quote! { Entity }
61        };
62        match self.rel_type {
63            RelationType::HasOne | RelationType::HasMany => {
64                quote! {
65                    Entity::#rel_type(#ref_entity).into()
66                }
67            }
68            RelationType::BelongsTo => {
69                let map_src_column = |src_column: &Ident| {
70                    quote! { Column::#src_column }
71                };
72                let map_ref_column = |ref_column: &Ident| {
73                    if module_name.is_some() {
74                        quote! { super::#module_name::Column::#ref_column }
75                    } else {
76                        quote! { Column::#ref_column }
77                    }
78                };
79                let map_punctuated =
80                    |punctuated: Punctuated<TokenStream, Comma>| match punctuated.len() {
81                        0..=1 => quote! { #punctuated },
82                        _ => quote! { (#punctuated) },
83                    };
84                let (from, to) =
85                    self.get_src_ref_columns(map_src_column, map_ref_column, map_punctuated);
86                quote! {
87                    Entity::#rel_type(#ref_entity)
88                        .from(#from)
89                        .to(#to)
90                        .into()
91                }
92            }
93        }
94    }
95
96    pub fn get_attrs(&self) -> TokenStream {
97        let rel_type = self.get_rel_type();
98        let module_name = if let Some(module_name) = self.get_module_name() {
99            format!("super::{module_name}::")
100        } else {
101            String::new()
102        };
103        let ref_entity = format!("{module_name}Entity");
104        match self.rel_type {
105            RelationType::HasOne | RelationType::HasMany => {
106                quote! {
107                    #[sea_orm(#rel_type = #ref_entity)]
108                }
109            }
110            RelationType::BelongsTo => {
111                let map_src_column = |src_column: &Ident| format!("Column::{src_column}");
112                let map_ref_column =
113                    |ref_column: &Ident| format!("{module_name}Column::{ref_column}");
114                let map_punctuated = |punctuated: Vec<String>| {
115                    let len = punctuated.len();
116                    let punctuated = punctuated.join(", ");
117                    match len {
118                        0..=1 => punctuated,
119                        _ => format!("({punctuated})"),
120                    }
121                };
122                let (from, to) =
123                    self.get_src_ref_columns(map_src_column, map_ref_column, map_punctuated);
124
125                let on_update = if let Some(action) = &self.on_update {
126                    let action = Self::get_foreign_key_action(action);
127                    quote! {
128                        on_update = #action,
129                    }
130                } else {
131                    quote! {}
132                };
133                let on_delete = if let Some(action) = &self.on_delete {
134                    let action = Self::get_foreign_key_action(action);
135                    quote! {
136                        on_delete = #action,
137                    }
138                } else {
139                    quote! {}
140                };
141                quote! {
142                    #[sea_orm(
143                        #rel_type = #ref_entity,
144                        from = #from,
145                        to = #to,
146                        #on_update
147                        #on_delete
148                    )]
149                }
150            }
151        }
152    }
153
154    pub fn get_rel_type(&self) -> Ident {
155        match self.rel_type {
156            RelationType::HasOne => format_ident!("has_one"),
157            RelationType::HasMany => format_ident!("has_many"),
158            RelationType::BelongsTo => format_ident!("belongs_to"),
159        }
160    }
161
162    pub fn get_column_camel_case(&self) -> Vec<Ident> {
163        self.columns
164            .iter()
165            .map(|col| format_ident!("{}", col.to_upper_camel_case()))
166            .collect()
167    }
168
169    pub fn get_ref_column_camel_case(&self) -> Vec<Ident> {
170        self.ref_columns
171            .iter()
172            .map(|col| format_ident!("{}", col.to_upper_camel_case()))
173            .collect()
174    }
175
176    pub fn get_foreign_key_action(action: &ForeignKeyAction) -> String {
177        action.variant_name().to_owned()
178    }
179
180    pub fn get_src_ref_columns<F1, F2, F3, T, I>(
181        &self,
182        map_src_column: F1,
183        map_ref_column: F2,
184        map_punctuated: F3,
185    ) -> (T, T)
186    where
187        F1: Fn(&Ident) -> T,
188        F2: Fn(&Ident) -> T,
189        F3: Fn(I) -> T,
190        I: Extend<T> + Default,
191    {
192        let from: I =
193            self.get_column_camel_case()
194                .iter()
195                .fold(I::default(), |mut acc, src_column| {
196                    acc.extend([map_src_column(src_column)]);
197                    acc
198                });
199        let to: I =
200            self.get_ref_column_camel_case()
201                .iter()
202                .fold(I::default(), |mut acc, ref_column| {
203                    acc.extend([map_ref_column(ref_column)]);
204                    acc
205                });
206
207        (map_punctuated(from), map_punctuated(to))
208    }
209}
210
211impl From<&TableForeignKey> for Relation {
212    fn from(tbl_fk: &TableForeignKey) -> Self {
213        let ref_table = match tbl_fk.get_ref_table() {
214            Some(s) => s.sea_orm_table().to_string(),
215            None => panic!("RefTable should not be empty"),
216        };
217        let columns = tbl_fk.get_columns();
218        let ref_columns = tbl_fk.get_ref_columns();
219        let rel_type = RelationType::BelongsTo;
220        let on_delete = tbl_fk.get_on_delete();
221        let on_update = tbl_fk.get_on_update();
222        Self {
223            ref_table,
224            columns,
225            ref_columns,
226            rel_type,
227            on_delete,
228            on_update,
229            self_referencing: false,
230            num_suffix: 0,
231            impl_related: true,
232        }
233    }
234}
235
#[cfg(test)]
mod tests {
    use crate::{Relation, RelationType};
    use proc_macro2::TokenStream;
    use sea_query::ForeignKeyAction;

    // Fixtures pointing at other modules. They cover one relation of each kind.
    fn setup() -> Vec<Relation> {
        vec![
            Relation {
                ref_table: "fruit".to_owned(),
                columns: vec!["id".to_owned()],
                ref_columns: vec!["cake_id".to_owned()],
                rel_type: RelationType::HasOne,
                on_delete: None,
                on_update: None,
                self_referencing: false,
                num_suffix: 0,
                impl_related: true,
            },
            Relation {
                ref_table: "filling".to_owned(),
                columns: vec!["filling_id".to_owned()],
                ref_columns: vec!["id".to_owned()],
                rel_type: RelationType::BelongsTo,
                on_delete: Some(ForeignKeyAction::Cascade),
                on_update: Some(ForeignKeyAction::Cascade),
                self_referencing: false,
                num_suffix: 0,
                impl_related: true,
            },
            Relation {
                ref_table: "filling".to_owned(),
                columns: vec!["filling_id".to_owned()],
                ref_columns: vec!["id".to_owned()],
                rel_type: RelationType::HasMany,
                on_delete: Some(ForeignKeyAction::Cascade),
                on_update: None,
                self_referencing: false,
                num_suffix: 0,
                impl_related: true,
            },
        ]
    }

    // A self-referencing `belongs_to` over a composite key, with a numeric suffix.
    // It exercises the `SelfRef`, suffix, no-module and tuple-column branches.
    fn setup_self_ref_composite() -> Relation {
        Relation {
            ref_table: "cake".to_owned(),
            columns: vec!["parent_id".to_owned(), "parent_name".to_owned()],
            ref_columns: vec!["id".to_owned(), "name".to_owned()],
            rel_type: RelationType::BelongsTo,
            on_delete: None,
            on_update: None,
            self_referencing: true,
            num_suffix: 2,
            impl_related: true,
        }
    }

    #[test]
    fn test_get_module_name() {
        let relations = setup();
        let snake_cases = vec!["fruit", "filling", "filling"];
        for (rel, snake_case) in relations.into_iter().zip(snake_cases) {
            assert_eq!(rel.get_module_name().unwrap().to_string(), snake_case);
        }
    }

    #[test]
    fn test_get_enum_name() {
        let relations = setup();
        let camel_cases = vec!["Fruit", "Filling", "Filling"];
        for (rel, camel_case) in relations.into_iter().zip(camel_cases) {
            assert_eq!(rel.get_enum_name().to_string(), camel_case);
        }
    }

    #[test]
    fn test_get_def() {
        let relations = setup();
        let rel_defs = vec![
            "Entity::has_one(super::fruit::Entity).into()",
            "Entity::belongs_to(super::filling::Entity) \
                .from(Column::FillingId) \
                .to(super::filling::Column::Id) \
                .into()",
            "Entity::has_many(super::filling::Entity).into()",
        ];
        for (rel, rel_def) in relations.into_iter().zip(rel_defs) {
            let rel_def: TokenStream = rel_def.parse().unwrap();

            assert_eq!(rel.get_def().to_string(), rel_def.to_string());
        }
    }

    #[test]
    fn test_get_attrs() {
        let relations = setup();
        let rel_attrs = vec![
            r#"#[sea_orm(has_one = "super::fruit::Entity")]"#,
            r#"#[sea_orm(
                belongs_to = "super::filling::Entity",
                from = "Column::FillingId",
                to = "super::filling::Column::Id",
                on_update = "Cascade",
                on_delete = "Cascade",
            )]"#,
            r#"#[sea_orm(has_many = "super::filling::Entity")]"#,
        ];
        for (rel, rel_attr) in relations.into_iter().zip(rel_attrs) {
            let rel_attr: TokenStream = rel_attr.parse().unwrap();

            assert_eq!(rel.get_attrs().to_string(), rel_attr.to_string());
        }
    }

    #[test]
    fn test_self_referencing_composite() {
        let rel = setup_self_ref_composite();

        assert_eq!(rel.get_enum_name().to_string(), "SelfRef2");
        assert!(rel.get_module_name().is_none());

        let expected_def: TokenStream = "Entity::belongs_to(Entity) \
            .from((Column::ParentId, Column::ParentName)) \
            .to((Column::Id, Column::Name)) \
            .into()"
            .parse()
            .unwrap();
        assert_eq!(rel.get_def().to_string(), expected_def.to_string());

        let expected_attrs: TokenStream = r#"#[sea_orm(
                belongs_to = "Entity",
                from = "(Column::ParentId, Column::ParentName)",
                to = "(Column::Id, Column::Name)",
            )]"#
            .parse()
            .unwrap();
        assert_eq!(rel.get_attrs().to_string(), expected_attrs.to_string());
    }

    #[test]
    fn test_get_rel_type() {
        let relations = setup();
        let rel_types = vec!["has_one", "belongs_to", "has_many"];
        for (rel, rel_type) in relations.into_iter().zip(rel_types) {
            assert_eq!(rel.get_rel_type(), rel_type);
        }
    }

    #[test]
    fn test_get_column_camel_case() {
        let relations = setup();
        let cols = vec!["Id", "FillingId", "FillingId"];
        for (rel, col) in relations.into_iter().zip(cols) {
            assert_eq!(rel.get_column_camel_case(), [col]);
        }
    }

    #[test]
    fn test_get_ref_column_camel_case() {
        let relations = setup();
        let ref_cols = vec!["CakeId", "Id", "Id"];
        for (rel, ref_col) in relations.into_iter().zip(ref_cols) {
            assert_eq!(rel.get_ref_column_camel_case(), [ref_col]);
        }
    }
}