// sea_orm_codegen/entity/relation.rs

1use crate::util::unpack_table_ref;
2use heck::{ToSnakeCase, ToUpperCamelCase};
3use proc_macro2::{Ident, TokenStream};
4use quote::{format_ident, quote};
5use sea_query::{ForeignKeyAction, TableForeignKey};
6use syn::{punctuated::Punctuated, token::Comma};
7
8use crate::util::escape_rust_keyword;
9
/// The kind of relationship a generated entity has with another table.
///
/// `HasOne` and `HasMany` carry no column mapping in the generated code, while
/// `BelongsTo` is rendered together with its `from`/`to` columns.
#[derive(Debug, Clone)]
pub enum RelationType {
    /// Rendered as `has_one`.
    HasOne,
    /// Rendered as `has_many`.
    HasMany,
    /// Rendered as `belongs_to`; this is the kind built from a foreign key.
    BelongsTo,
}
16
17#[derive(Clone, Debug)]
18pub struct Relation {
19    pub(crate) ref_table: String,
20    pub(crate) columns: Vec<String>,
21    pub(crate) ref_columns: Vec<String>,
22    pub(crate) rel_type: RelationType,
23    pub(crate) on_update: Option<ForeignKeyAction>,
24    pub(crate) on_delete: Option<ForeignKeyAction>,
25    pub(crate) self_referencing: bool,
26    pub(crate) num_suffix: usize,
27    pub(crate) impl_related: bool,
28}
29
30impl Relation {
31    pub fn get_enum_name(&self) -> Ident {
32        let name = if self.self_referencing {
33            format_ident!("SelfRef")
34        } else {
35            format_ident!("{}", self.ref_table.to_upper_camel_case())
36        };
37        if self.num_suffix > 0 {
38            format_ident!("{}{}", name, self.num_suffix)
39        } else {
40            name
41        }
42    }
43
44    pub fn get_module_name(&self) -> Option<Ident> {
45        if self.self_referencing {
46            None
47        } else {
48            Some(format_ident!(
49                "{}",
50                escape_rust_keyword(self.ref_table.to_snake_case())
51            ))
52        }
53    }
54
55    pub fn get_def(&self) -> TokenStream {
56        let rel_type = self.get_rel_type();
57        let module_name = self.get_module_name();
58        let ref_entity = if module_name.is_some() {
59            quote! { super::#module_name::Entity }
60        } else {
61            quote! { Entity }
62        };
63        match self.rel_type {
64            RelationType::HasOne | RelationType::HasMany => {
65                quote! {
66                    Entity::#rel_type(#ref_entity).into()
67                }
68            }
69            RelationType::BelongsTo => {
70                let map_src_column = |src_column: &Ident| {
71                    quote! { Column::#src_column }
72                };
73                let map_ref_column = |ref_column: &Ident| {
74                    if module_name.is_some() {
75                        quote! { super::#module_name::Column::#ref_column }
76                    } else {
77                        quote! { Column::#ref_column }
78                    }
79                };
80                let map_punctuated =
81                    |punctuated: Punctuated<TokenStream, Comma>| match punctuated.len() {
82                        0..=1 => quote! { #punctuated },
83                        _ => quote! { (#punctuated) },
84                    };
85                let (from, to) =
86                    self.get_src_ref_columns(map_src_column, map_ref_column, map_punctuated);
87                quote! {
88                    Entity::#rel_type(#ref_entity)
89                        .from(#from)
90                        .to(#to)
91                        .into()
92                }
93            }
94        }
95    }
96
97    pub fn get_attrs(&self) -> TokenStream {
98        let rel_type = self.get_rel_type();
99        let module_name = if let Some(module_name) = self.get_module_name() {
100            format!("super::{module_name}::")
101        } else {
102            String::new()
103        };
104        let ref_entity = format!("{module_name}Entity");
105        match self.rel_type {
106            RelationType::HasOne | RelationType::HasMany => {
107                quote! {
108                    #[sea_orm(#rel_type = #ref_entity)]
109                }
110            }
111            RelationType::BelongsTo => {
112                let map_src_column = |src_column: &Ident| format!("Column::{src_column}");
113                let map_ref_column =
114                    |ref_column: &Ident| format!("{module_name}Column::{ref_column}");
115                let map_punctuated = |punctuated: Vec<String>| {
116                    let len = punctuated.len();
117                    let punctuated = punctuated.join(", ");
118                    match len {
119                        0..=1 => punctuated,
120                        _ => format!("({})", punctuated),
121                    }
122                };
123                let (from, to) =
124                    self.get_src_ref_columns(map_src_column, map_ref_column, map_punctuated);
125
126                let on_update = if let Some(action) = &self.on_update {
127                    let action = Self::get_foreign_key_action(action);
128                    quote! {
129                        on_update = #action,
130                    }
131                } else {
132                    quote! {}
133                };
134                let on_delete = if let Some(action) = &self.on_delete {
135                    let action = Self::get_foreign_key_action(action);
136                    quote! {
137                        on_delete = #action,
138                    }
139                } else {
140                    quote! {}
141                };
142                quote! {
143                    #[sea_orm(
144                        #rel_type = #ref_entity,
145                        from = #from,
146                        to = #to,
147                        #on_update
148                        #on_delete
149                    )]
150                }
151            }
152        }
153    }
154
155    pub fn get_rel_type(&self) -> Ident {
156        match self.rel_type {
157            RelationType::HasOne => format_ident!("has_one"),
158            RelationType::HasMany => format_ident!("has_many"),
159            RelationType::BelongsTo => format_ident!("belongs_to"),
160        }
161    }
162
163    pub fn get_column_camel_case(&self) -> Vec<Ident> {
164        self.columns
165            .iter()
166            .map(|col| format_ident!("{}", col.to_upper_camel_case()))
167            .collect()
168    }
169
170    pub fn get_ref_column_camel_case(&self) -> Vec<Ident> {
171        self.ref_columns
172            .iter()
173            .map(|col| format_ident!("{}", col.to_upper_camel_case()))
174            .collect()
175    }
176
177    pub fn get_foreign_key_action(action: &ForeignKeyAction) -> String {
178        match action {
179            ForeignKeyAction::Restrict => "Restrict",
180            ForeignKeyAction::Cascade => "Cascade",
181            ForeignKeyAction::SetNull => "SetNull",
182            ForeignKeyAction::NoAction => "NoAction",
183            ForeignKeyAction::SetDefault => "SetDefault",
184        }
185        .to_owned()
186    }
187
188    pub fn get_src_ref_columns<F1, F2, F3, T, I>(
189        &self,
190        map_src_column: F1,
191        map_ref_column: F2,
192        map_punctuated: F3,
193    ) -> (T, T)
194    where
195        F1: Fn(&Ident) -> T,
196        F2: Fn(&Ident) -> T,
197        F3: Fn(I) -> T,
198        I: Extend<T> + Default,
199    {
200        let from: I =
201            self.get_column_camel_case()
202                .iter()
203                .fold(I::default(), |mut acc, src_column| {
204                    acc.extend([map_src_column(src_column)]);
205                    acc
206                });
207        let to: I =
208            self.get_ref_column_camel_case()
209                .iter()
210                .fold(I::default(), |mut acc, ref_column| {
211                    acc.extend([map_ref_column(ref_column)]);
212                    acc
213                });
214
215        (map_punctuated(from), map_punctuated(to))
216    }
217}
218
219impl From<&TableForeignKey> for Relation {
220    fn from(tbl_fk: &TableForeignKey) -> Self {
221        let ref_table = match tbl_fk.get_ref_table() {
222            Some(s) => unpack_table_ref(s),
223            None => panic!("RefTable should not be empty"),
224        };
225        let columns = tbl_fk.get_columns();
226        let ref_columns = tbl_fk.get_ref_columns();
227        let rel_type = RelationType::BelongsTo;
228        let on_delete = tbl_fk.get_on_delete();
229        let on_update = tbl_fk.get_on_update();
230        Self {
231            ref_table,
232            columns,
233            ref_columns,
234            rel_type,
235            on_delete,
236            on_update,
237            self_referencing: false,
238            num_suffix: 0,
239            impl_related: true,
240        }
241    }
242}
243
#[cfg(test)]
mod tests {
    use crate::{Relation, RelationType};
    use proc_macro2::TokenStream;
    use sea_query::ForeignKeyAction;

    /// Shared fixtures: a `HasOne`, a `BelongsTo` with cascading actions, and a
    /// `HasMany`. They are consumed in order by the per-method tests below.
    fn setup() -> Vec<Relation> {
        vec![
            Relation {
                ref_table: "fruit".to_owned(),
                columns: vec!["id".to_owned()],
                ref_columns: vec!["cake_id".to_owned()],
                rel_type: RelationType::HasOne,
                on_delete: None,
                on_update: None,
                self_referencing: false,
                num_suffix: 0,
                impl_related: true,
            },
            Relation {
                ref_table: "filling".to_owned(),
                columns: vec!["filling_id".to_owned()],
                ref_columns: vec!["id".to_owned()],
                rel_type: RelationType::BelongsTo,
                on_delete: Some(ForeignKeyAction::Cascade),
                on_update: Some(ForeignKeyAction::Cascade),
                self_referencing: false,
                num_suffix: 0,
                impl_related: true,
            },
            Relation {
                ref_table: "filling".to_owned(),
                columns: vec!["filling_id".to_owned()],
                ref_columns: vec!["id".to_owned()],
                rel_type: RelationType::HasMany,
                on_delete: Some(ForeignKeyAction::Cascade),
                on_update: None,
                self_referencing: false,
                num_suffix: 0,
                impl_related: true,
            },
        ]
    }

    /// Asserts that `actual` renders to the same tokens as the Rust source `expected`.
    fn assert_tokens_eq(actual: TokenStream, expected: &str) {
        let expected: TokenStream = expected.parse().unwrap();
        assert_eq!(actual.to_string(), expected.to_string());
    }

    #[test]
    fn test_get_module_name() {
        let relations = setup();
        let snake_cases = vec!["fruit", "filling", "filling"];
        for (rel, snake_case) in relations.into_iter().zip(snake_cases) {
            assert_eq!(rel.get_module_name().unwrap().to_string(), snake_case);
        }
    }

    #[test]
    fn test_get_enum_name() {
        let relations = setup();
        let camel_cases = vec!["Fruit", "Filling", "Filling"];
        for (rel, camel_case) in relations.into_iter().zip(camel_cases) {
            assert_eq!(rel.get_enum_name().to_string(), camel_case);
        }
    }

    #[test]
    fn test_get_def() {
        let relations = setup();
        let rel_defs = vec![
            "Entity::has_one(super::fruit::Entity).into()",
            "Entity::belongs_to(super::filling::Entity) \
                .from(Column::FillingId) \
                .to(super::filling::Column::Id) \
                .into()",
            "Entity::has_many(super::filling::Entity).into()",
        ];
        for (rel, rel_def) in relations.into_iter().zip(rel_defs) {
            let rel_def: TokenStream = rel_def.parse().unwrap();

            assert_eq!(rel.get_def().to_string(), rel_def.to_string());
        }
    }

    #[test]
    fn test_get_attrs() {
        let relations = setup();
        let rel_attrs = vec![
            r#"#[sea_orm(has_one = "super::fruit::Entity")]"#,
            r#"#[sea_orm(
                belongs_to = "super::filling::Entity",
                from = "Column::FillingId",
                to = "super::filling::Column::Id",
                on_update = "Cascade",
                on_delete = "Cascade",
            )]"#,
            r#"#[sea_orm(has_many = "super::filling::Entity")]"#,
        ];
        for (rel, rel_attr) in relations.into_iter().zip(rel_attrs) {
            assert_tokens_eq(rel.get_attrs(), rel_attr);
        }
    }

    #[test]
    fn test_self_referencing_relation() {
        let rel = Relation {
            ref_table: "cake".to_owned(),
            columns: vec!["parent_id".to_owned()],
            ref_columns: vec!["id".to_owned()],
            rel_type: RelationType::BelongsTo,
            on_delete: None,
            on_update: None,
            self_referencing: true,
            num_suffix: 2,
            impl_related: true,
        };
        assert_eq!(rel.get_enum_name().to_string(), "SelfRef2");
        assert!(rel.get_module_name().is_none());
        assert_tokens_eq(
            rel.get_def(),
            "Entity::belongs_to(Entity) \
                .from(Column::ParentId) \
                .to(Column::Id) \
                .into()",
        );
        assert_tokens_eq(
            rel.get_attrs(),
            r#"#[sea_orm(
                belongs_to = "Entity",
                from = "Column::ParentId",
                to = "Column::Id",
            )]"#,
        );
    }

    #[test]
    fn test_composite_key_relation() {
        let rel = Relation {
            ref_table: "cake".to_owned(),
            columns: vec!["cake_id".to_owned(), "cake_name".to_owned()],
            ref_columns: vec!["id".to_owned(), "name".to_owned()],
            rel_type: RelationType::BelongsTo,
            on_delete: None,
            on_update: Some(ForeignKeyAction::SetNull),
            self_referencing: false,
            num_suffix: 0,
            impl_related: true,
        };
        assert_tokens_eq(
            rel.get_def(),
            "Entity::belongs_to(super::cake::Entity) \
                .from((Column::CakeId, Column::CakeName)) \
                .to((super::cake::Column::Id, super::cake::Column::Name)) \
                .into()",
        );
        assert_tokens_eq(
            rel.get_attrs(),
            r#"#[sea_orm(
                belongs_to = "super::cake::Entity",
                from = "(Column::CakeId, Column::CakeName)",
                to = "(super::cake::Column::Id, super::cake::Column::Name)",
                on_update = "SetNull",
            )]"#,
        );
    }

    #[test]
    fn test_get_rel_type() {
        let relations = setup();
        let rel_types = vec!["has_one", "belongs_to", "has_many"];
        for (rel, rel_type) in relations.into_iter().zip(rel_types) {
            assert_eq!(rel.get_rel_type(), rel_type);
        }
    }

    #[test]
    fn test_get_column_camel_case() {
        let relations = setup();
        let cols = vec!["Id", "FillingId", "FillingId"];
        for (rel, col) in relations.into_iter().zip(cols) {
            assert_eq!(rel.get_column_camel_case(), [col]);
        }
    }

    #[test]
    fn test_get_ref_column_camel_case() {
        let relations = setup();
        let ref_cols = vec!["CakeId", "Id", "Id"];
        for (rel, ref_col) in relations.into_iter().zip(ref_cols) {
            assert_eq!(rel.get_ref_column_camel_case(), [ref_col]);
        }
    }
}