1use crate::util::unpack_table_ref;
2use heck::{ToSnakeCase, ToUpperCamelCase};
3use proc_macro2::{Ident, TokenStream};
4use quote::{format_ident, quote};
5use sea_query::{ForeignKeyAction, TableForeignKey};
6use syn::{punctuated::Punctuated, token::Comma};
7
8use crate::util::escape_rust_keyword;
9
/// Kind of relation between two generated entities.
#[derive(Debug, Clone)]
pub enum RelationType {
    /// Rendered as a `has_one` relation.
    HasOne,
    /// Rendered as a `has_many` relation.
    HasMany,
    /// Rendered as a `belongs_to` relation; the only kind for which `from`/`to`
    /// columns and foreign-key actions are emitted.
    BelongsTo,
}
16
17#[derive(Clone, Debug)]
18pub struct Relation {
19 pub(crate) ref_table: String,
20 pub(crate) columns: Vec<String>,
21 pub(crate) ref_columns: Vec<String>,
22 pub(crate) rel_type: RelationType,
23 pub(crate) on_update: Option<ForeignKeyAction>,
24 pub(crate) on_delete: Option<ForeignKeyAction>,
25 pub(crate) self_referencing: bool,
26 pub(crate) num_suffix: usize,
27 pub(crate) impl_related: bool,
28}
29
30impl Relation {
31 pub fn get_enum_name(&self) -> Ident {
32 let name = if self.self_referencing {
33 format_ident!("SelfRef")
34 } else {
35 format_ident!("{}", self.ref_table.to_upper_camel_case())
36 };
37 if self.num_suffix > 0 {
38 format_ident!("{}{}", name, self.num_suffix)
39 } else {
40 name
41 }
42 }
43
44 pub fn get_module_name(&self) -> Option<Ident> {
45 if self.self_referencing {
46 None
47 } else {
48 Some(format_ident!(
49 "{}",
50 escape_rust_keyword(self.ref_table.to_snake_case())
51 ))
52 }
53 }
54
55 pub fn get_def(&self) -> TokenStream {
56 let rel_type = self.get_rel_type();
57 let module_name = self.get_module_name();
58 let ref_entity = if module_name.is_some() {
59 quote! { super::#module_name::Entity }
60 } else {
61 quote! { Entity }
62 };
63 match self.rel_type {
64 RelationType::HasOne | RelationType::HasMany => {
65 quote! {
66 Entity::#rel_type(#ref_entity).into()
67 }
68 }
69 RelationType::BelongsTo => {
70 let map_src_column = |src_column: &Ident| {
71 quote! { Column::#src_column }
72 };
73 let map_ref_column = |ref_column: &Ident| {
74 if module_name.is_some() {
75 quote! { super::#module_name::Column::#ref_column }
76 } else {
77 quote! { Column::#ref_column }
78 }
79 };
80 let map_punctuated =
81 |punctuated: Punctuated<TokenStream, Comma>| match punctuated.len() {
82 0..=1 => quote! { #punctuated },
83 _ => quote! { (#punctuated) },
84 };
85 let (from, to) =
86 self.get_src_ref_columns(map_src_column, map_ref_column, map_punctuated);
87 quote! {
88 Entity::#rel_type(#ref_entity)
89 .from(#from)
90 .to(#to)
91 .into()
92 }
93 }
94 }
95 }
96
97 pub fn get_attrs(&self) -> TokenStream {
98 let rel_type = self.get_rel_type();
99 let module_name = if let Some(module_name) = self.get_module_name() {
100 format!("super::{module_name}::")
101 } else {
102 String::new()
103 };
104 let ref_entity = format!("{module_name}Entity");
105 match self.rel_type {
106 RelationType::HasOne | RelationType::HasMany => {
107 quote! {
108 #[sea_orm(#rel_type = #ref_entity)]
109 }
110 }
111 RelationType::BelongsTo => {
112 let map_src_column = |src_column: &Ident| format!("Column::{src_column}");
113 let map_ref_column =
114 |ref_column: &Ident| format!("{module_name}Column::{ref_column}");
115 let map_punctuated = |punctuated: Vec<String>| {
116 let len = punctuated.len();
117 let punctuated = punctuated.join(", ");
118 match len {
119 0..=1 => punctuated,
120 _ => format!("({})", punctuated),
121 }
122 };
123 let (from, to) =
124 self.get_src_ref_columns(map_src_column, map_ref_column, map_punctuated);
125
126 let on_update = if let Some(action) = &self.on_update {
127 let action = Self::get_foreign_key_action(action);
128 quote! {
129 on_update = #action,
130 }
131 } else {
132 quote! {}
133 };
134 let on_delete = if let Some(action) = &self.on_delete {
135 let action = Self::get_foreign_key_action(action);
136 quote! {
137 on_delete = #action,
138 }
139 } else {
140 quote! {}
141 };
142 quote! {
143 #[sea_orm(
144 #rel_type = #ref_entity,
145 from = #from,
146 to = #to,
147 #on_update
148 #on_delete
149 )]
150 }
151 }
152 }
153 }
154
155 pub fn get_rel_type(&self) -> Ident {
156 match self.rel_type {
157 RelationType::HasOne => format_ident!("has_one"),
158 RelationType::HasMany => format_ident!("has_many"),
159 RelationType::BelongsTo => format_ident!("belongs_to"),
160 }
161 }
162
163 pub fn get_column_camel_case(&self) -> Vec<Ident> {
164 self.columns
165 .iter()
166 .map(|col| format_ident!("{}", col.to_upper_camel_case()))
167 .collect()
168 }
169
170 pub fn get_ref_column_camel_case(&self) -> Vec<Ident> {
171 self.ref_columns
172 .iter()
173 .map(|col| format_ident!("{}", col.to_upper_camel_case()))
174 .collect()
175 }
176
177 pub fn get_foreign_key_action(action: &ForeignKeyAction) -> String {
178 match action {
179 ForeignKeyAction::Restrict => "Restrict",
180 ForeignKeyAction::Cascade => "Cascade",
181 ForeignKeyAction::SetNull => "SetNull",
182 ForeignKeyAction::NoAction => "NoAction",
183 ForeignKeyAction::SetDefault => "SetDefault",
184 }
185 .to_owned()
186 }
187
188 pub fn get_src_ref_columns<F1, F2, F3, T, I>(
189 &self,
190 map_src_column: F1,
191 map_ref_column: F2,
192 map_punctuated: F3,
193 ) -> (T, T)
194 where
195 F1: Fn(&Ident) -> T,
196 F2: Fn(&Ident) -> T,
197 F3: Fn(I) -> T,
198 I: Extend<T> + Default,
199 {
200 let from: I =
201 self.get_column_camel_case()
202 .iter()
203 .fold(I::default(), |mut acc, src_column| {
204 acc.extend([map_src_column(src_column)]);
205 acc
206 });
207 let to: I =
208 self.get_ref_column_camel_case()
209 .iter()
210 .fold(I::default(), |mut acc, ref_column| {
211 acc.extend([map_ref_column(ref_column)]);
212 acc
213 });
214
215 (map_punctuated(from), map_punctuated(to))
216 }
217}
218
219impl From<&TableForeignKey> for Relation {
220 fn from(tbl_fk: &TableForeignKey) -> Self {
221 let ref_table = match tbl_fk.get_ref_table() {
222 Some(s) => unpack_table_ref(s),
223 None => panic!("RefTable should not be empty"),
224 };
225 let columns = tbl_fk.get_columns();
226 let ref_columns = tbl_fk.get_ref_columns();
227 let rel_type = RelationType::BelongsTo;
228 let on_delete = tbl_fk.get_on_delete();
229 let on_update = tbl_fk.get_on_update();
230 Self {
231 ref_table,
232 columns,
233 ref_columns,
234 rel_type,
235 on_delete,
236 on_update,
237 self_referencing: false,
238 num_suffix: 0,
239 impl_related: true,
240 }
241 }
242}
243
#[cfg(test)]
mod tests {
    use crate::{Relation, RelationType};
    use proc_macro2::TokenStream;
    use sea_query::ForeignKeyAction;

    /// Fixture relations: a `has_one` to `fruit`, a `belongs_to` to `filling`
    /// with cascading actions, and a `has_many` to `filling`.
    fn setup() -> Vec<Relation> {
        vec![
            Relation {
                ref_table: "fruit".to_owned(),
                columns: vec!["id".to_owned()],
                ref_columns: vec!["cake_id".to_owned()],
                rel_type: RelationType::HasOne,
                on_delete: None,
                on_update: None,
                self_referencing: false,
                num_suffix: 0,
                impl_related: true,
            },
            Relation {
                ref_table: "filling".to_owned(),
                columns: vec!["filling_id".to_owned()],
                ref_columns: vec!["id".to_owned()],
                rel_type: RelationType::BelongsTo,
                on_delete: Some(ForeignKeyAction::Cascade),
                on_update: Some(ForeignKeyAction::Cascade),
                self_referencing: false,
                num_suffix: 0,
                impl_related: true,
            },
            Relation {
                ref_table: "filling".to_owned(),
                columns: vec!["filling_id".to_owned()],
                ref_columns: vec!["id".to_owned()],
                rel_type: RelationType::HasMany,
                on_delete: Some(ForeignKeyAction::Cascade),
                on_update: None,
                self_referencing: false,
                num_suffix: 0,
                impl_related: true,
            },
        ]
    }

    /// Parses `expected` as Rust tokens and compares its normalised string form
    /// with `actual`, so whitespace in the expected literal does not matter.
    fn assert_tokens_eq(actual: TokenStream, expected: &str) {
        let expected: TokenStream = expected.parse().unwrap();
        assert_eq!(actual.to_string(), expected.to_string());
    }

    #[test]
    fn test_get_module_name() {
        let relations = setup();
        let snake_cases = vec!["fruit", "filling", "filling"];
        for (rel, snake_case) in relations.into_iter().zip(snake_cases) {
            assert_eq!(rel.get_module_name().unwrap().to_string(), snake_case);
        }
    }

    #[test]
    fn test_get_enum_name() {
        let relations = setup();
        let camel_cases = vec!["Fruit", "Filling", "Filling"];
        for (rel, camel_case) in relations.into_iter().zip(camel_cases) {
            assert_eq!(rel.get_enum_name().to_string(), camel_case);
        }
    }

    #[test]
    fn test_get_def() {
        let relations = setup();
        let rel_defs = vec![
            "Entity::has_one(super::fruit::Entity).into()",
            "Entity::belongs_to(super::filling::Entity) \
                .from(Column::FillingId) \
                .to(super::filling::Column::Id) \
                .into()",
            "Entity::has_many(super::filling::Entity).into()",
        ];
        for (rel, rel_def) in relations.into_iter().zip(rel_defs) {
            assert_tokens_eq(rel.get_def(), rel_def);
        }
    }

    #[test]
    fn test_get_attrs() {
        let relations = setup();
        let rel_attrs = vec![
            r#"#[sea_orm(has_one = "super::fruit::Entity")]"#,
            r#"#[sea_orm(
                belongs_to = "super::filling::Entity",
                from = "Column::FillingId",
                to = "super::filling::Column::Id",
                on_update = "Cascade",
                on_delete = "Cascade",
            )]"#,
            // `has_many` ignores foreign-key actions.
            r#"#[sea_orm(has_many = "super::filling::Entity")]"#,
        ];
        for (rel, rel_attr) in relations.into_iter().zip(rel_attrs) {
            assert_tokens_eq(rel.get_attrs(), rel_attr);
        }
    }

    #[test]
    fn test_self_referencing() {
        let rel = Relation {
            ref_table: "cake".to_owned(),
            columns: vec!["parent_id".to_owned()],
            ref_columns: vec!["id".to_owned()],
            rel_type: RelationType::BelongsTo,
            on_delete: None,
            on_update: None,
            self_referencing: true,
            num_suffix: 2,
            impl_related: true,
        };
        assert!(rel.get_module_name().is_none());
        assert_eq!(rel.get_enum_name().to_string(), "SelfRef2");
        assert_tokens_eq(
            rel.get_def(),
            "Entity::belongs_to(Entity).from(Column::ParentId).to(Column::Id).into()",
        );
        assert_tokens_eq(
            rel.get_attrs(),
            r#"#[sea_orm(belongs_to = "Entity", from = "Column::ParentId", to = "Column::Id",)]"#,
        );
    }

    #[test]
    fn test_composite_key() {
        let rel = Relation {
            ref_table: "cake_fruit".to_owned(),
            columns: vec!["cake_id".to_owned(), "fruit_id".to_owned()],
            ref_columns: vec!["cake_id".to_owned(), "fruit_id".to_owned()],
            rel_type: RelationType::BelongsTo,
            on_delete: None,
            on_update: None,
            self_referencing: false,
            num_suffix: 0,
            impl_related: true,
        };
        assert_tokens_eq(
            rel.get_def(),
            "Entity::belongs_to(super::cake_fruit::Entity) \
                .from((Column::CakeId, Column::FruitId)) \
                .to((super::cake_fruit::Column::CakeId, super::cake_fruit::Column::FruitId)) \
                .into()",
        );
        assert_tokens_eq(
            rel.get_attrs(),
            r#"#[sea_orm(
                belongs_to = "super::cake_fruit::Entity",
                from = "(Column::CakeId, Column::FruitId)",
                to = "(super::cake_fruit::Column::CakeId, super::cake_fruit::Column::FruitId)",
            )]"#,
        );
    }

    #[test]
    fn test_get_rel_type() {
        let relations = setup();
        let rel_types = vec!["has_one", "belongs_to", "has_many"];
        for (rel, rel_type) in relations.into_iter().zip(rel_types) {
            assert_eq!(rel.get_rel_type(), rel_type);
        }
    }

    #[test]
    fn test_get_column_camel_case() {
        let relations = setup();
        let cols = vec!["Id", "FillingId", "FillingId"];
        for (rel, col) in relations.into_iter().zip(cols) {
            assert_eq!(rel.get_column_camel_case(), [col]);
        }
    }

    #[test]
    fn test_get_ref_column_camel_case() {
        let relations = setup();
        let ref_cols = vec!["CakeId", "Id", "Id"];
        for (rel, ref_col) in relations.into_iter().zip(ref_cols) {
            assert_eq!(rel.get_ref_column_camel_case(), [ref_col]);
        }
    }
}