1use heck::{ToSnakeCase, ToUpperCamelCase};
2use proc_macro2::{Ident, TokenStream};
3use quote::format_ident;
4use quote::quote;
5use sea_query::ColumnType;
6
7use crate::{
8 util::escape_rust_keyword, Column, ConjunctRelation, DateTimeCrate, PrimaryKey, Relation,
9};
10
11#[derive(Clone, Debug)]
12pub struct Entity {
13 pub(crate) table_name: String,
14 pub(crate) columns: Vec<Column>,
15 pub(crate) relations: Vec<Relation>,
16 pub(crate) conjunct_relations: Vec<ConjunctRelation>,
17 pub(crate) primary_keys: Vec<PrimaryKey>,
18}
19
20impl Entity {
21 pub fn get_table_name_snake_case(&self) -> String {
22 self.table_name.to_snake_case()
23 }
24
25 pub fn get_table_name_camel_case(&self) -> String {
26 self.table_name.to_upper_camel_case()
27 }
28
29 pub fn get_table_name_snake_case_ident(&self) -> Ident {
30 format_ident!("{}", escape_rust_keyword(self.get_table_name_snake_case()))
31 }
32
33 pub fn get_table_name_camel_case_ident(&self) -> Ident {
34 format_ident!("{}", escape_rust_keyword(self.get_table_name_camel_case()))
35 }
36
37 pub fn get_column_names_snake_case(&self) -> Vec<Ident> {
38 self.columns
39 .iter()
40 .map(|col| col.get_name_snake_case())
41 .collect()
42 }
43
44 pub fn get_column_names_camel_case(&self) -> Vec<Ident> {
45 self.columns
46 .iter()
47 .map(|col| col.get_name_camel_case())
48 .collect()
49 }
50
51 pub fn get_column_rs_types(&self, date_time_crate: &DateTimeCrate) -> Vec<TokenStream> {
52 self.columns
53 .clone()
54 .into_iter()
55 .map(|col| col.get_rs_type(date_time_crate))
56 .collect()
57 }
58
59 pub fn get_column_defs(&self) -> Vec<TokenStream> {
60 self.columns
61 .clone()
62 .into_iter()
63 .map(|col| col.get_def())
64 .collect()
65 }
66
67 pub fn get_primary_key_names_snake_case(&self) -> Vec<Ident> {
68 self.primary_keys
69 .iter()
70 .map(|pk| pk.get_name_snake_case())
71 .collect()
72 }
73
74 pub fn get_primary_key_names_camel_case(&self) -> Vec<Ident> {
75 self.primary_keys
76 .iter()
77 .map(|pk| pk.get_name_camel_case())
78 .collect()
79 }
80
81 pub fn get_relation_module_name(&self) -> Vec<Option<Ident>> {
82 self.relations
83 .iter()
84 .map(|rel| rel.get_module_name())
85 .collect()
86 }
87
88 pub fn get_relation_enum_name(&self) -> Vec<Ident> {
89 self.relations
90 .iter()
91 .map(|rel| rel.get_enum_name())
92 .collect()
93 }
94
95 pub fn get_related_entity_enum_name(&self) -> Vec<Ident> {
97 let conjunct_related_names = self.get_conjunct_relations_to_upper_camel_case();
99
100 let self_relations_reverse = self
102 .relations
103 .iter()
104 .filter(|rel| rel.self_referencing)
105 .map(|rel| format_ident!("{}Reverse", rel.get_enum_name()));
106
107 self.get_relation_enum_name()
109 .into_iter()
110 .chain(self_relations_reverse)
111 .chain(conjunct_related_names)
112 .collect()
113 }
114
115 pub fn get_relation_defs(&self) -> Vec<TokenStream> {
116 self.relations.iter().map(|rel| rel.get_def()).collect()
117 }
118
119 pub fn get_relation_attrs(&self) -> Vec<TokenStream> {
120 self.relations.iter().map(|rel| rel.get_attrs()).collect()
121 }
122
123 pub fn get_related_entity_attrs(&self) -> Vec<TokenStream> {
125 let conjunct_related_attrs = self.conjunct_relations.iter().map(|conj| {
127 let entity = format!("super::{}::Entity", conj.get_to_snake_case());
128
129 quote! {
130 #[sea_orm(
131 entity = #entity
132 )]
133 }
134 });
135
136 let produce_relation_attrs = |rel: &Relation, reverse: bool| {
138 let entity = match rel.get_module_name() {
139 Some(module_name) => format!("super::{}::Entity", module_name),
140 None => String::from("Entity"),
141 };
142
143 if rel.self_referencing || !rel.impl_related || rel.num_suffix > 0 {
144 let def = if reverse {
145 format!("Relation::{}.def().rev()", rel.get_enum_name())
146 } else {
147 format!("Relation::{}.def()", rel.get_enum_name())
148 };
149
150 quote! {
151 #[sea_orm(
152 entity = #entity,
153 def = #def
154 )]
155 }
156 } else {
157 quote! {
158 #[sea_orm(
159 entity = #entity
160 )]
161 }
162 }
163 };
164
165 let self_relations_reverse_attrs = self
167 .relations
168 .iter()
169 .filter(|rel| rel.self_referencing)
170 .map(|rel| produce_relation_attrs(rel, true));
171
172 self.relations
174 .iter()
175 .map(|rel| produce_relation_attrs(rel, false))
176 .chain(self_relations_reverse_attrs)
177 .chain(conjunct_related_attrs)
178 .collect()
179 }
180
181 pub fn get_primary_key_auto_increment(&self) -> Ident {
182 let auto_increment = self.columns.iter().any(|col| col.auto_increment);
183 format_ident!("{}", auto_increment)
184 }
185
186 pub fn get_primary_key_rs_type(&self, date_time_crate: &DateTimeCrate) -> TokenStream {
187 let types = self
188 .primary_keys
189 .iter()
190 .map(|primary_key| {
191 self.columns
192 .iter()
193 .find(|col| col.name.eq(&primary_key.name))
194 .unwrap()
195 .get_rs_type(date_time_crate)
196 .to_string()
197 })
198 .collect::<Vec<_>>();
199 if !types.is_empty() {
200 let value_type = if types.len() > 1 {
201 vec!["(".to_owned(), types.join(", "), ")".to_owned()]
202 } else {
203 types
204 };
205 value_type.join("").parse().unwrap()
206 } else {
207 TokenStream::new()
208 }
209 }
210
211 pub fn get_conjunct_relations_via_snake_case(&self) -> Vec<Ident> {
212 self.conjunct_relations
213 .iter()
214 .map(|con_rel| con_rel.get_via_snake_case())
215 .collect()
216 }
217
218 pub fn get_conjunct_relations_to_snake_case(&self) -> Vec<Ident> {
219 self.conjunct_relations
220 .iter()
221 .map(|con_rel| con_rel.get_to_snake_case())
222 .collect()
223 }
224
225 pub fn get_conjunct_relations_to_upper_camel_case(&self) -> Vec<Ident> {
226 self.conjunct_relations
227 .iter()
228 .map(|con_rel| con_rel.get_to_upper_camel_case())
229 .collect()
230 }
231
232 pub fn get_eq_needed(&self) -> TokenStream {
233 fn is_floats(col_type: &ColumnType) -> bool {
234 match col_type {
235 ColumnType::Float | ColumnType::Double => true,
236 ColumnType::Array(col_type) => is_floats(col_type),
237 _ => false,
238 }
239 }
240 self.columns
241 .iter()
242 .find(|column| is_floats(&column.col_type))
243 .map_or(quote! {, Eq}, |_| quote! {})
246 }
247
248 pub fn get_column_serde_attributes(
249 &self,
250 serde_skip_deserializing_primary_key: bool,
251 serde_skip_hidden_column: bool,
252 ) -> Vec<TokenStream> {
253 self.columns
254 .iter()
255 .map(|col| {
256 let is_primary_key = self.primary_keys.iter().any(|pk| pk.name == col.name);
257 col.get_serde_attribute(
258 is_primary_key,
259 serde_skip_deserializing_primary_key,
260 serde_skip_hidden_column,
261 )
262 })
263 .collect()
264 }
265}
266
#[cfg(test)]
mod tests {
    use quote::{format_ident, quote};
    use sea_query::{ColumnType, ForeignKeyAction, StringLen};

    use crate::{Column, DateTimeCrate, Entity, PrimaryKey, Relation, RelationType};

    /// Builds a two-column `special_cake` entity with two plain `HasOne`
    /// relations, no conjunct relations, and `id` as its single primary key.
    fn setup() -> Entity {
        Entity {
            table_name: "special_cake".to_owned(),
            columns: vec![
                Column {
                    name: "id".to_owned(),
                    col_type: ColumnType::Integer,
                    auto_increment: false,
                    not_null: false,
                    unique: false,
                },
                Column {
                    name: "name".to_owned(),
                    col_type: ColumnType::String(StringLen::None),
                    auto_increment: false,
                    not_null: false,
                    unique: false,
                },
            ],
            relations: vec![
                Relation {
                    ref_table: "fruit".to_owned(),
                    columns: vec!["id".to_owned()],
                    ref_columns: vec!["cake_id".to_owned()],
                    rel_type: RelationType::HasOne,
                    on_delete: Some(ForeignKeyAction::Cascade),
                    on_update: Some(ForeignKeyAction::Cascade),
                    self_referencing: false,
                    num_suffix: 0,
                    impl_related: true,
                },
                Relation {
                    ref_table: "filling".to_owned(),
                    columns: vec!["id".to_owned()],
                    ref_columns: vec!["cake_id".to_owned()],
                    rel_type: RelationType::HasOne,
                    on_delete: Some(ForeignKeyAction::Cascade),
                    on_update: Some(ForeignKeyAction::Cascade),
                    self_referencing: false,
                    num_suffix: 0,
                    impl_related: true,
                },
            ],
            conjunct_relations: vec![],
            primary_keys: vec![PrimaryKey {
                name: "id".to_owned(),
            }],
        }
    }

    #[test]
    fn test_get_table_name_snake_case() {
        let entity = setup();

        assert_eq!(
            entity.get_table_name_snake_case(),
            "special_cake".to_owned()
        );
    }

    #[test]
    fn test_get_table_name_camel_case() {
        let entity = setup();

        assert_eq!(entity.get_table_name_camel_case(), "SpecialCake".to_owned());
    }

    #[test]
    fn test_get_table_name_snake_case_ident() {
        let entity = setup();

        assert_eq!(
            entity.get_table_name_snake_case_ident(),
            format_ident!("{}", "special_cake")
        );
    }

    #[test]
    fn test_get_table_name_camel_case_ident() {
        let entity = setup();

        assert_eq!(
            entity.get_table_name_camel_case_ident(),
            format_ident!("{}", "SpecialCake")
        );
    }

    #[test]
    fn test_get_column_names_snake_case() {
        let entity = setup();

        for (i, elem) in entity.get_column_names_snake_case().into_iter().enumerate() {
            assert_eq!(elem, entity.columns[i].get_name_snake_case());
        }
    }

    #[test]
    fn test_get_column_names_camel_case() {
        let entity = setup();

        for (i, elem) in entity.get_column_names_camel_case().into_iter().enumerate() {
            assert_eq!(elem, entity.columns[i].get_name_camel_case());
        }
    }

    #[test]
    fn test_get_column_rs_types() {
        let entity = setup();

        for (i, elem) in entity
            .get_column_rs_types(&DateTimeCrate::Chrono)
            .into_iter()
            .enumerate()
        {
            assert_eq!(
                elem.to_string(),
                entity.columns[i]
                    .get_rs_type(&DateTimeCrate::Chrono)
                    .to_string()
            );
        }
    }

    #[test]
    fn test_get_column_defs() {
        let entity = setup();

        for (i, elem) in entity.get_column_defs().into_iter().enumerate() {
            assert_eq!(elem.to_string(), entity.columns[i].get_def().to_string());
        }
    }

    #[test]
    fn test_get_primary_key_names_snake_case() {
        let entity = setup();

        for (i, elem) in entity
            .get_primary_key_names_snake_case()
            .into_iter()
            .enumerate()
        {
            assert_eq!(elem, entity.primary_keys[i].get_name_snake_case());
        }
    }

    #[test]
    fn test_get_primary_key_names_camel_case() {
        let entity = setup();

        for (i, elem) in entity
            .get_primary_key_names_camel_case()
            .into_iter()
            .enumerate()
        {
            assert_eq!(elem, entity.primary_keys[i].get_name_camel_case());
        }
    }

    #[test]
    fn test_get_relation_module_name() {
        let entity = setup();

        for (i, elem) in entity.get_relation_module_name().into_iter().enumerate() {
            assert_eq!(elem, entity.relations[i].get_module_name());
        }
    }

    #[test]
    fn test_get_relation_enum_name() {
        let entity = setup();

        for (i, elem) in entity.get_relation_enum_name().into_iter().enumerate() {
            assert_eq!(elem, entity.relations[i].get_enum_name());
        }
    }

    #[test]
    fn test_get_related_entity_enum_name() {
        let entity = setup();

        // With no self-referencing or conjunct relations, the related-entity
        // names are exactly the relation enum names.
        assert_eq!(
            entity.get_related_entity_enum_name(),
            entity.get_relation_enum_name()
        );
    }

    #[test]
    fn test_get_relation_defs() {
        let entity = setup();

        for (i, elem) in entity.get_relation_defs().into_iter().enumerate() {
            assert_eq!(elem.to_string(), entity.relations[i].get_def().to_string());
        }
    }

    #[test]
    fn test_get_relation_attrs() {
        let entity = setup();

        for (i, elem) in entity.get_relation_attrs().into_iter().enumerate() {
            assert_eq!(
                elem.to_string(),
                entity.relations[i].get_attrs().to_string()
            );
        }
    }

    #[test]
    fn test_get_primary_key_auto_increment() {
        let mut entity = setup();

        assert_eq!(
            entity.get_primary_key_auto_increment(),
            format_ident!("{}", false)
        );

        entity.columns[0].auto_increment = true;
        assert_eq!(
            entity.get_primary_key_auto_increment(),
            format_ident!("{}", true)
        );
    }

    #[test]
    fn test_get_primary_key_rs_type() {
        let entity = setup();

        assert_eq!(
            entity
                .get_primary_key_rs_type(&DateTimeCrate::Chrono)
                .to_string(),
            entity.columns[0]
                .get_rs_type(&DateTimeCrate::Chrono)
                .to_string()
        );
    }

    #[test]
    fn test_get_primary_key_rs_type_composite() {
        let mut entity = setup();
        entity.primary_keys.push(PrimaryKey {
            name: "name".to_owned(),
        });
        let id_type = entity.columns[0].get_rs_type(&DateTimeCrate::Chrono);
        let name_type = entity.columns[1].get_rs_type(&DateTimeCrate::Chrono);

        // A composite key is rendered as a tuple of the key columns' types.
        assert_eq!(
            entity
                .get_primary_key_rs_type(&DateTimeCrate::Chrono)
                .to_string(),
            quote! { (#id_type, #name_type) }.to_string()
        );
    }

    #[test]
    fn test_get_primary_key_rs_type_without_primary_key() {
        let mut entity = setup();
        entity.primary_keys.clear();

        assert!(entity
            .get_primary_key_rs_type(&DateTimeCrate::Chrono)
            .is_empty());
    }

    #[test]
    fn test_get_conjunct_relations_via_snake_case() {
        let entity = setup();

        for (i, elem) in entity
            .get_conjunct_relations_via_snake_case()
            .into_iter()
            .enumerate()
        {
            assert_eq!(elem, entity.conjunct_relations[i].get_via_snake_case());
        }
    }

    #[test]
    fn test_get_conjunct_relations_to_snake_case() {
        let entity = setup();

        for (i, elem) in entity
            .get_conjunct_relations_to_snake_case()
            .into_iter()
            .enumerate()
        {
            assert_eq!(elem, entity.conjunct_relations[i].get_to_snake_case());
        }
    }

    #[test]
    fn test_get_conjunct_relations_to_upper_camel_case() {
        let entity = setup();

        for (i, elem) in entity
            .get_conjunct_relations_to_upper_camel_case()
            .into_iter()
            .enumerate()
        {
            assert_eq!(elem, entity.conjunct_relations[i].get_to_upper_camel_case());
        }
    }

    #[test]
    fn test_get_eq_needed() {
        let entity = setup();
        let expected = quote! {, Eq};

        assert_eq!(entity.get_eq_needed().to_string(), expected.to_string());
    }

    #[test]
    fn test_get_eq_needed_with_float_column() {
        let mut entity = setup();
        entity.columns.push(Column {
            name: "price".to_owned(),
            col_type: ColumnType::Double,
            auto_increment: false,
            not_null: false,
            unique: false,
        });

        // A floating-point column means `Eq` must not be derived.
        assert!(entity.get_eq_needed().is_empty());
    }
}