1use heck::{ToSnakeCase, ToUpperCamelCase};
2use proc_macro2::{Ident, TokenStream};
3use quote::{format_ident, quote};
4use sea_query::{ForeignKeyAction, TableForeignKey};
5use syn::{punctuated::Punctuated, token::Comma};
6
7use crate::util::escape_rust_keyword;
8
/// The kind of association a `Relation` describes.
///
/// Variants are rendered as the `has_one`, `has_many` and `belongs_to` relation
/// constructors / attribute keys by the code generator.
///
/// `PartialEq`/`Eq` are derived so variants can be compared directly instead of
/// only through `match`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RelationType {
    /// Rendered as `has_one`.
    HasOne,
    /// Rendered as `has_many`.
    HasMany,
    /// Rendered as `belongs_to`; the only variant that carries `from`/`to` columns and
    /// `on_update`/`on_delete` actions in the generated output.
    BelongsTo,
}
15
16#[derive(Clone, Debug)]
17pub struct Relation {
18 pub(crate) ref_table: String,
19 pub(crate) columns: Vec<String>,
20 pub(crate) ref_columns: Vec<String>,
21 pub(crate) rel_type: RelationType,
22 pub(crate) on_update: Option<ForeignKeyAction>,
23 pub(crate) on_delete: Option<ForeignKeyAction>,
24 pub(crate) self_referencing: bool,
25 pub(crate) num_suffix: usize,
26 pub(crate) impl_related: bool,
27}
28
29impl Relation {
30 pub fn get_enum_name(&self) -> Ident {
31 let name = if self.self_referencing {
32 format_ident!("SelfRef")
33 } else {
34 format_ident!("{}", self.ref_table.to_upper_camel_case())
35 };
36 if self.num_suffix > 0 {
37 format_ident!("{}{}", name, self.num_suffix)
38 } else {
39 name
40 }
41 }
42
43 pub fn get_module_name(&self) -> Option<Ident> {
44 if self.self_referencing {
45 None
46 } else {
47 Some(format_ident!(
48 "{}",
49 escape_rust_keyword(self.ref_table.to_snake_case())
50 ))
51 }
52 }
53
54 pub fn get_def(&self) -> TokenStream {
55 let rel_type = self.get_rel_type();
56 let module_name = self.get_module_name();
57 let ref_entity = if module_name.is_some() {
58 quote! { super::#module_name::Entity }
59 } else {
60 quote! { Entity }
61 };
62 match self.rel_type {
63 RelationType::HasOne | RelationType::HasMany => {
64 quote! {
65 Entity::#rel_type(#ref_entity).into()
66 }
67 }
68 RelationType::BelongsTo => {
69 let map_src_column = |src_column: &Ident| {
70 quote! { Column::#src_column }
71 };
72 let map_ref_column = |ref_column: &Ident| {
73 if module_name.is_some() {
74 quote! { super::#module_name::Column::#ref_column }
75 } else {
76 quote! { Column::#ref_column }
77 }
78 };
79 let map_punctuated =
80 |punctuated: Punctuated<TokenStream, Comma>| match punctuated.len() {
81 0..=1 => quote! { #punctuated },
82 _ => quote! { (#punctuated) },
83 };
84 let (from, to) =
85 self.get_src_ref_columns(map_src_column, map_ref_column, map_punctuated);
86 quote! {
87 Entity::#rel_type(#ref_entity)
88 .from(#from)
89 .to(#to)
90 .into()
91 }
92 }
93 }
94 }
95
96 pub fn get_attrs(&self) -> TokenStream {
97 let rel_type = self.get_rel_type();
98 let module_name = if let Some(module_name) = self.get_module_name() {
99 format!("super::{module_name}::")
100 } else {
101 String::new()
102 };
103 let ref_entity = format!("{module_name}Entity");
104 match self.rel_type {
105 RelationType::HasOne | RelationType::HasMany => {
106 quote! {
107 #[sea_orm(#rel_type = #ref_entity)]
108 }
109 }
110 RelationType::BelongsTo => {
111 let map_src_column = |src_column: &Ident| format!("Column::{src_column}");
112 let map_ref_column =
113 |ref_column: &Ident| format!("{module_name}Column::{ref_column}");
114 let map_punctuated = |punctuated: Vec<String>| {
115 let len = punctuated.len();
116 let punctuated = punctuated.join(", ");
117 match len {
118 0..=1 => punctuated,
119 _ => format!("({punctuated})"),
120 }
121 };
122 let (from, to) =
123 self.get_src_ref_columns(map_src_column, map_ref_column, map_punctuated);
124
125 let on_update = if let Some(action) = &self.on_update {
126 let action = Self::get_foreign_key_action(action);
127 quote! {
128 on_update = #action,
129 }
130 } else {
131 quote! {}
132 };
133 let on_delete = if let Some(action) = &self.on_delete {
134 let action = Self::get_foreign_key_action(action);
135 quote! {
136 on_delete = #action,
137 }
138 } else {
139 quote! {}
140 };
141 quote! {
142 #[sea_orm(
143 #rel_type = #ref_entity,
144 from = #from,
145 to = #to,
146 #on_update
147 #on_delete
148 )]
149 }
150 }
151 }
152 }
153
154 pub fn get_rel_type(&self) -> Ident {
155 match self.rel_type {
156 RelationType::HasOne => format_ident!("has_one"),
157 RelationType::HasMany => format_ident!("has_many"),
158 RelationType::BelongsTo => format_ident!("belongs_to"),
159 }
160 }
161
162 pub fn get_column_camel_case(&self) -> Vec<Ident> {
163 self.columns
164 .iter()
165 .map(|col| format_ident!("{}", col.to_upper_camel_case()))
166 .collect()
167 }
168
169 pub fn get_ref_column_camel_case(&self) -> Vec<Ident> {
170 self.ref_columns
171 .iter()
172 .map(|col| format_ident!("{}", col.to_upper_camel_case()))
173 .collect()
174 }
175
176 pub fn get_foreign_key_action(action: &ForeignKeyAction) -> String {
177 action.variant_name().to_owned()
178 }
179
180 pub fn get_src_ref_columns<F1, F2, F3, T, I>(
181 &self,
182 map_src_column: F1,
183 map_ref_column: F2,
184 map_punctuated: F3,
185 ) -> (T, T)
186 where
187 F1: Fn(&Ident) -> T,
188 F2: Fn(&Ident) -> T,
189 F3: Fn(I) -> T,
190 I: Extend<T> + Default,
191 {
192 let from: I =
193 self.get_column_camel_case()
194 .iter()
195 .fold(I::default(), |mut acc, src_column| {
196 acc.extend([map_src_column(src_column)]);
197 acc
198 });
199 let to: I =
200 self.get_ref_column_camel_case()
201 .iter()
202 .fold(I::default(), |mut acc, ref_column| {
203 acc.extend([map_ref_column(ref_column)]);
204 acc
205 });
206
207 (map_punctuated(from), map_punctuated(to))
208 }
209}
210
211impl From<&TableForeignKey> for Relation {
212 fn from(tbl_fk: &TableForeignKey) -> Self {
213 let ref_table = match tbl_fk.get_ref_table() {
214 Some(s) => s.sea_orm_table().to_string(),
215 None => panic!("RefTable should not be empty"),
216 };
217 let columns = tbl_fk.get_columns();
218 let ref_columns = tbl_fk.get_ref_columns();
219 let rel_type = RelationType::BelongsTo;
220 let on_delete = tbl_fk.get_on_delete();
221 let on_update = tbl_fk.get_on_update();
222 Self {
223 ref_table,
224 columns,
225 ref_columns,
226 rel_type,
227 on_delete,
228 on_update,
229 self_referencing: false,
230 num_suffix: 0,
231 impl_related: true,
232 }
233 }
234}
235
#[cfg(test)]
mod tests {
    use crate::{Relation, RelationType};
    use proc_macro2::TokenStream;
    use sea_query::ForeignKeyAction;

    /// Fixture relations shared by most tests: a `has_one` to `fruit`, a `belongs_to`
    /// on `filling` with cascading actions, and a `has_many` to `filling`.
    fn setup() -> Vec<Relation> {
        vec![
            Relation {
                ref_table: "fruit".to_owned(),
                columns: vec!["id".to_owned()],
                ref_columns: vec!["cake_id".to_owned()],
                rel_type: RelationType::HasOne,
                on_delete: None,
                on_update: None,
                self_referencing: false,
                num_suffix: 0,
                impl_related: true,
            },
            Relation {
                ref_table: "filling".to_owned(),
                columns: vec!["filling_id".to_owned()],
                ref_columns: vec!["id".to_owned()],
                rel_type: RelationType::BelongsTo,
                on_delete: Some(ForeignKeyAction::Cascade),
                on_update: Some(ForeignKeyAction::Cascade),
                self_referencing: false,
                num_suffix: 0,
                impl_related: true,
            },
            Relation {
                ref_table: "filling".to_owned(),
                columns: vec!["filling_id".to_owned()],
                ref_columns: vec!["id".to_owned()],
                rel_type: RelationType::HasMany,
                on_delete: Some(ForeignKeyAction::Cascade),
                on_update: None,
                self_referencing: false,
                num_suffix: 0,
                impl_related: true,
            },
        ]
    }

    /// A self-referencing `belongs_to` relation carrying a numeric suffix.
    fn setup_self_ref() -> Relation {
        Relation {
            ref_table: "cake".to_owned(),
            columns: vec!["parent_id".to_owned()],
            ref_columns: vec!["id".to_owned()],
            rel_type: RelationType::BelongsTo,
            on_delete: None,
            on_update: None,
            self_referencing: true,
            num_suffix: 1,
            impl_related: true,
        }
    }

    /// A `belongs_to` relation over a two-column (composite) foreign key.
    fn setup_composite() -> Relation {
        Relation {
            ref_table: "cake".to_owned(),
            columns: vec!["cake_id".to_owned(), "cake_version".to_owned()],
            ref_columns: vec!["id".to_owned(), "version".to_owned()],
            rel_type: RelationType::BelongsTo,
            on_delete: None,
            on_update: None,
            self_referencing: false,
            num_suffix: 0,
            impl_related: true,
        }
    }

    /// Round-trips Rust source through the tokenizer so comparisons ignore whitespace.
    fn normalize(src: &str) -> String {
        src.parse::<TokenStream>().unwrap().to_string()
    }

    #[test]
    fn test_get_module_name() {
        let relations = setup();
        let snake_cases = ["fruit", "filling", "filling"];
        for (rel, snake_case) in relations.into_iter().zip(snake_cases) {
            assert_eq!(rel.get_module_name().unwrap().to_string(), snake_case);
        }
    }

    #[test]
    fn test_get_enum_name() {
        let relations = setup();
        let camel_cases = ["Fruit", "Filling", "Filling"];
        for (rel, camel_case) in relations.into_iter().zip(camel_cases) {
            assert_eq!(rel.get_enum_name().to_string(), camel_case);
        }
    }

    #[test]
    fn test_get_def() {
        let relations = setup();
        let rel_defs = [
            "Entity::has_one(super::fruit::Entity).into()",
            "Entity::belongs_to(super::filling::Entity) \
                .from(Column::FillingId) \
                .to(super::filling::Column::Id) \
                .into()",
            "Entity::has_many(super::filling::Entity).into()",
        ];
        for (rel, rel_def) in relations.into_iter().zip(rel_defs) {
            assert_eq!(rel.get_def().to_string(), normalize(rel_def));
        }
    }

    #[test]
    fn test_get_attrs() {
        let relations = setup();
        let rel_attrs = [
            r#"#[sea_orm(has_one = "super::fruit::Entity")]"#,
            r#"#[sea_orm(
                belongs_to = "super::filling::Entity",
                from = "Column::FillingId",
                to = "super::filling::Column::Id",
                on_update = "Cascade",
                on_delete = "Cascade",
            )]"#,
            r#"#[sea_orm(has_many = "super::filling::Entity")]"#,
        ];
        for (rel, rel_attr) in relations.into_iter().zip(rel_attrs) {
            assert_eq!(rel.get_attrs().to_string(), normalize(rel_attr));
        }
    }

    #[test]
    fn test_get_rel_type() {
        let relations = setup();
        let rel_types = ["has_one", "belongs_to", "has_many"];
        for (rel, rel_type) in relations.into_iter().zip(rel_types) {
            assert_eq!(rel.get_rel_type(), rel_type);
        }
    }

    #[test]
    fn test_get_column_camel_case() {
        let relations = setup();
        let cols = ["Id", "FillingId", "FillingId"];
        for (rel, col) in relations.into_iter().zip(cols) {
            assert_eq!(rel.get_column_camel_case(), [col]);
        }
    }

    #[test]
    fn test_get_ref_column_camel_case() {
        let relations = setup();
        let ref_cols = ["CakeId", "Id", "Id"];
        for (rel, ref_col) in relations.into_iter().zip(ref_cols) {
            assert_eq!(rel.get_ref_column_camel_case(), [ref_col]);
        }
    }

    #[test]
    fn test_self_referencing() {
        let rel = setup_self_ref();
        assert_eq!(rel.get_enum_name().to_string(), "SelfRef1");
        assert!(rel.get_module_name().is_none());
        assert_eq!(rel.get_rel_type(), "belongs_to");
        assert_eq!(
            rel.get_def().to_string(),
            normalize("Entity::belongs_to(Entity).from(Column::ParentId).to(Column::Id).into()")
        );
        assert_eq!(
            rel.get_attrs().to_string(),
            normalize(
                r#"#[sea_orm(belongs_to = "Entity", from = "Column::ParentId", to = "Column::Id",)]"#
            )
        );
    }

    #[test]
    fn test_composite_foreign_key() {
        let rel = setup_composite();
        assert_eq!(rel.get_column_camel_case(), ["CakeId", "CakeVersion"]);
        assert_eq!(rel.get_ref_column_camel_case(), ["Id", "Version"]);
        assert_eq!(
            rel.get_def().to_string(),
            normalize(
                "Entity::belongs_to(super::cake::Entity) \
                    .from((Column::CakeId, Column::CakeVersion)) \
                    .to((super::cake::Column::Id, super::cake::Column::Version)) \
                    .into()"
            )
        );
        assert_eq!(
            rel.get_attrs().to_string(),
            normalize(
                r#"#[sea_orm(
                    belongs_to = "super::cake::Entity",
                    from = "(Column::CakeId, Column::CakeVersion)",
                    to = "(super::cake::Column::Id, super::cake::Column::Version)",
                )]"#
            )
        );
    }
}