1use heck::{ToSnakeCase, ToUpperCamelCase};
2use proc_macro2::{Ident, TokenStream};
3use quote::format_ident;
4use quote::quote;
5use sea_query::ColumnType;
6
7use crate::{
8 Column, ConjunctRelation, DateTimeCrate, PrimaryKey, Relation, util::escape_rust_keyword,
9};
10
/// Description of a single database table.
///
/// The methods on this type derive the identifiers and `TokenStream` fragments
/// (column types, relation definitions, `#[sea_orm(...)]` attributes, ...) used to
/// generate code for the table.
#[derive(Clone, Debug)]
pub struct Entity {
    /// Table name as given; converted to snake / UpperCamel case on demand.
    pub(crate) table_name: String,
    /// Columns of the table; every per-column getter preserves this order.
    pub(crate) columns: Vec<Column>,
    /// Relations to other tables, or to this table itself when a relation's
    /// `self_referencing` flag is set.
    pub(crate) relations: Vec<Relation>,
    /// Conjunct relations, each exposing a `via` and a `to` table name.
    pub(crate) conjunct_relations: Vec<ConjunctRelation>,
    /// Primary-key entries; each `name` is matched against `Column::name`.
    pub(crate) primary_keys: Vec<PrimaryKey>,
}
19
20impl Entity {
21 pub fn get_table_name_snake_case(&self) -> String {
22 self.table_name.to_snake_case()
23 }
24
25 pub fn get_table_name_camel_case(&self) -> String {
26 self.table_name.to_upper_camel_case()
27 }
28
29 pub fn get_table_name_snake_case_ident(&self) -> Ident {
30 format_ident!("{}", escape_rust_keyword(self.get_table_name_snake_case()))
31 }
32
33 pub fn get_table_name_camel_case_ident(&self) -> Ident {
34 format_ident!("{}", escape_rust_keyword(self.get_table_name_camel_case()))
35 }
36
37 pub fn get_column_names_snake_case(&self) -> Vec<Ident> {
38 self.columns
39 .iter()
40 .map(|col| col.get_name_snake_case())
41 .collect()
42 }
43
44 pub fn get_column_names_camel_case(&self) -> Vec<Ident> {
45 self.columns
46 .iter()
47 .map(|col| col.get_name_camel_case())
48 .collect()
49 }
50
51 pub fn get_column_rs_types(&self, date_time_crate: &DateTimeCrate) -> Vec<TokenStream> {
52 self.columns
53 .clone()
54 .into_iter()
55 .map(|col| col.get_rs_type(date_time_crate))
56 .collect()
57 }
58
59 pub fn get_column_defs(&self) -> Vec<TokenStream> {
60 self.columns
61 .clone()
62 .into_iter()
63 .map(|col| col.get_def())
64 .collect()
65 }
66
67 pub fn get_primary_key_names_snake_case(&self) -> Vec<Ident> {
68 self.primary_keys
69 .iter()
70 .map(|pk| pk.get_name_snake_case())
71 .collect()
72 }
73
74 pub fn get_primary_key_names_camel_case(&self) -> Vec<Ident> {
75 self.primary_keys
76 .iter()
77 .map(|pk| pk.get_name_camel_case())
78 .collect()
79 }
80
81 pub fn get_relation_module_name(&self) -> Vec<Option<Ident>> {
82 self.relations
83 .iter()
84 .map(|rel| rel.get_module_name())
85 .collect()
86 }
87
88 pub fn get_relation_enum_name(&self) -> Vec<Ident> {
89 self.relations
90 .iter()
91 .map(|rel| rel.get_enum_name())
92 .collect()
93 }
94
95 pub fn get_related_entity_enum_name(&self) -> Vec<Ident> {
97 let conjunct_related_names = self.get_conjunct_relations_to_upper_camel_case();
99
100 let self_relations_reverse = self
102 .relations
103 .iter()
104 .filter(|rel| rel.self_referencing)
105 .map(|rel| format_ident!("{}Reverse", rel.get_enum_name()));
106
107 self.get_relation_enum_name()
109 .into_iter()
110 .chain(self_relations_reverse)
111 .chain(conjunct_related_names)
112 .collect()
113 }
114
115 pub fn get_relation_defs(&self) -> Vec<TokenStream> {
116 self.relations.iter().map(|rel| rel.get_def()).collect()
117 }
118
119 pub fn get_relation_attrs(&self) -> Vec<TokenStream> {
120 self.relations.iter().map(|rel| rel.get_attrs()).collect()
121 }
122
123 pub fn get_related_entity_attrs(&self) -> Vec<TokenStream> {
125 let conjunct_related_attrs = self.conjunct_relations.iter().map(|conj| {
127 let entity = format!("super::{}::Entity", conj.get_to_snake_case());
128
129 quote! {
130 #[sea_orm(
131 entity = #entity
132 )]
133 }
134 });
135
136 let produce_relation_attrs = |rel: &Relation, reverse: bool| {
138 let entity = match rel.get_module_name() {
139 Some(module_name) => format!("super::{}::Entity", module_name),
140 None => String::from("Entity"),
141 };
142
143 if rel.self_referencing || !rel.impl_related || rel.num_suffix > 0 {
144 let def = if reverse {
145 format!("Relation::{}.def().rev()", rel.get_enum_name())
146 } else {
147 format!("Relation::{}.def()", rel.get_enum_name())
148 };
149
150 quote! {
151 #[sea_orm(
152 entity = #entity,
153 def = #def
154 )]
155 }
156 } else {
157 quote! {
158 #[sea_orm(
159 entity = #entity
160 )]
161 }
162 }
163 };
164
165 let self_relations_reverse_attrs = self
167 .relations
168 .iter()
169 .filter(|rel| rel.self_referencing)
170 .map(|rel| produce_relation_attrs(rel, true));
171
172 self.relations
174 .iter()
175 .map(|rel| produce_relation_attrs(rel, false))
176 .chain(self_relations_reverse_attrs)
177 .chain(conjunct_related_attrs)
178 .collect()
179 }
180
181 pub fn get_primary_key_auto_increment(&self) -> Ident {
182 let auto_increment = self.columns.iter().any(|col| col.auto_increment);
183 format_ident!("{}", auto_increment)
184 }
185
186 pub fn get_primary_key_rs_type(&self, date_time_crate: &DateTimeCrate) -> TokenStream {
187 let types = self
188 .primary_keys
189 .iter()
190 .map(|primary_key| {
191 self.columns
192 .iter()
193 .find(|col| col.name.eq(&primary_key.name))
194 .unwrap()
195 .get_rs_type(date_time_crate)
196 .to_string()
197 })
198 .collect::<Vec<_>>();
199 if !types.is_empty() {
200 let value_type = if types.len() > 1 {
201 vec!["(".to_owned(), types.join(", "), ")".to_owned()]
202 } else {
203 types
204 };
205 value_type.join("").parse().unwrap()
206 } else {
207 TokenStream::new()
208 }
209 }
210
211 pub fn get_conjunct_relations_via_snake_case(&self) -> Vec<Ident> {
212 self.conjunct_relations
213 .iter()
214 .map(|con_rel| con_rel.get_via_snake_case())
215 .collect()
216 }
217
218 pub fn get_conjunct_relations_to_snake_case(&self) -> Vec<Ident> {
219 self.conjunct_relations
220 .iter()
221 .map(|con_rel| con_rel.get_to_snake_case())
222 .collect()
223 }
224
225 pub fn get_conjunct_relations_to_upper_camel_case(&self) -> Vec<Ident> {
226 self.conjunct_relations
227 .iter()
228 .map(|con_rel| con_rel.get_to_upper_camel_case())
229 .collect()
230 }
231
232 pub fn get_eq_needed(&self) -> TokenStream {
233 fn is_floats(col_type: &ColumnType) -> bool {
234 match col_type {
235 ColumnType::Float | ColumnType::Double => true,
236 ColumnType::Array(col_type) => is_floats(col_type),
237 ColumnType::Vector(_) => true,
238 _ => false,
239 }
240 }
241 self.columns
242 .iter()
243 .find(|column| is_floats(&column.col_type))
244 .map_or(quote! {, Eq}, |_| quote! {})
247 }
248
249 pub fn get_column_serde_attributes(
250 &self,
251 serde_skip_deserializing_primary_key: bool,
252 serde_skip_hidden_column: bool,
253 ) -> Vec<TokenStream> {
254 self.columns
255 .iter()
256 .map(|col| {
257 let is_primary_key = self.primary_keys.iter().any(|pk| pk.name == col.name);
258 col.get_serde_attribute(
259 is_primary_key,
260 serde_skip_deserializing_primary_key,
261 serde_skip_hidden_column,
262 )
263 })
264 .collect()
265 }
266}
267
#[cfg(test)]
mod tests {
    use quote::{format_ident, quote};
    use sea_query::{ColumnType, ForeignKeyAction, StringLen};

    use crate::{Column, DateTimeCrate, Entity, PrimaryKey, Relation, RelationType};

    /// Builds the `special_cake` fixture. It has:
    /// - an integer `id` primary key and a string `name` column;
    /// - two `HasOne` relations (`fruit`, `filling`);
    /// - no conjunct relations.
    fn setup() -> Entity {
        Entity {
            table_name: "special_cake".to_owned(),
            columns: vec![
                Column {
                    name: "id".to_owned(),
                    col_type: ColumnType::Integer,
                    auto_increment: false,
                    not_null: false,
                    unique: false,
                },
                Column {
                    name: "name".to_owned(),
                    col_type: ColumnType::String(StringLen::None),
                    auto_increment: false,
                    not_null: false,
                    unique: false,
                },
            ],
            relations: vec![
                Relation {
                    ref_table: "fruit".to_owned(),
                    columns: vec!["id".to_owned()],
                    ref_columns: vec!["cake_id".to_owned()],
                    rel_type: RelationType::HasOne,
                    on_delete: Some(ForeignKeyAction::Cascade),
                    on_update: Some(ForeignKeyAction::Cascade),
                    self_referencing: false,
                    num_suffix: 0,
                    impl_related: true,
                },
                Relation {
                    ref_table: "filling".to_owned(),
                    columns: vec!["id".to_owned()],
                    ref_columns: vec!["cake_id".to_owned()],
                    rel_type: RelationType::HasOne,
                    on_delete: Some(ForeignKeyAction::Cascade),
                    on_update: Some(ForeignKeyAction::Cascade),
                    self_referencing: false,
                    num_suffix: 0,
                    impl_related: true,
                },
            ],
            conjunct_relations: vec![],
            primary_keys: vec![PrimaryKey {
                name: "id".to_owned(),
            }],
        }
    }

    #[test]
    fn test_get_table_name_snake_case() {
        let entity = setup();

        assert_eq!(
            entity.get_table_name_snake_case(),
            "special_cake".to_owned()
        );
    }

    #[test]
    fn test_get_table_name_camel_case() {
        let entity = setup();

        assert_eq!(entity.get_table_name_camel_case(), "SpecialCake".to_owned());
    }

    #[test]
    fn test_get_table_name_snake_case_ident() {
        let entity = setup();

        assert_eq!(
            entity.get_table_name_snake_case_ident(),
            format_ident!("{}", "special_cake")
        );
    }

    #[test]
    fn test_get_table_name_camel_case_ident() {
        let entity = setup();

        assert_eq!(
            entity.get_table_name_camel_case_ident(),
            format_ident!("{}", "SpecialCake")
        );
    }

    // Every per-item test below first checks the length. `zip` stops at the
    // shorter side, so a getter returning too few items would otherwise pass.

    #[test]
    fn test_get_column_names_snake_case() {
        let entity = setup();
        let names = entity.get_column_names_snake_case();

        assert_eq!(names.len(), entity.columns.len());
        for (name, column) in names.into_iter().zip(&entity.columns) {
            assert_eq!(name, column.get_name_snake_case());
        }
    }

    #[test]
    fn test_get_column_names_camel_case() {
        let entity = setup();
        let names = entity.get_column_names_camel_case();

        assert_eq!(names.len(), entity.columns.len());
        for (name, column) in names.into_iter().zip(&entity.columns) {
            assert_eq!(name, column.get_name_camel_case());
        }
    }

    #[test]
    fn test_get_column_rs_types() {
        let entity = setup();
        let types = entity.get_column_rs_types(&DateTimeCrate::Chrono);

        assert_eq!(types.len(), entity.columns.len());
        for (ty, column) in types.iter().zip(&entity.columns) {
            assert_eq!(
                ty.to_string(),
                column.get_rs_type(&DateTimeCrate::Chrono).to_string()
            );
        }
    }

    #[test]
    fn test_get_column_defs() {
        let entity = setup();
        let defs = entity.get_column_defs();

        assert_eq!(defs.len(), entity.columns.len());
        for (def, column) in defs.iter().zip(&entity.columns) {
            assert_eq!(def.to_string(), column.get_def().to_string());
        }
    }

    #[test]
    fn test_get_primary_key_names_snake_case() {
        let entity = setup();
        let names = entity.get_primary_key_names_snake_case();

        assert_eq!(names.len(), entity.primary_keys.len());
        for (name, pk) in names.into_iter().zip(&entity.primary_keys) {
            assert_eq!(name, pk.get_name_snake_case());
        }
    }

    #[test]
    fn test_get_primary_key_names_camel_case() {
        let entity = setup();
        let names = entity.get_primary_key_names_camel_case();

        assert_eq!(names.len(), entity.primary_keys.len());
        for (name, pk) in names.into_iter().zip(&entity.primary_keys) {
            assert_eq!(name, pk.get_name_camel_case());
        }
    }

    #[test]
    fn test_get_relation_module_name() {
        let entity = setup();
        let names = entity.get_relation_module_name();

        assert_eq!(names.len(), entity.relations.len());
        for (name, rel) in names.into_iter().zip(&entity.relations) {
            assert_eq!(name, rel.get_module_name());
        }
    }

    #[test]
    fn test_get_relation_enum_name() {
        let entity = setup();
        let names = entity.get_relation_enum_name();

        assert_eq!(names.len(), entity.relations.len());
        for (name, rel) in names.into_iter().zip(&entity.relations) {
            assert_eq!(name, rel.get_enum_name());
        }
    }

    #[test]
    fn test_get_related_entity_enum_name() {
        let entity = setup();

        // The fixture has no self-referencing and no conjunct relations, so the
        // related-entity names are exactly the relation enum names.
        assert_eq!(
            entity.get_related_entity_enum_name(),
            entity.get_relation_enum_name()
        );
    }

    #[test]
    fn test_get_related_entity_attrs() {
        let entity = setup();

        // One attribute per related-entity variant.
        assert_eq!(
            entity.get_related_entity_attrs().len(),
            entity.get_related_entity_enum_name().len()
        );
    }

    #[test]
    fn test_get_relation_defs() {
        let entity = setup();
        let defs = entity.get_relation_defs();

        assert_eq!(defs.len(), entity.relations.len());
        for (def, rel) in defs.iter().zip(&entity.relations) {
            assert_eq!(def.to_string(), rel.get_def().to_string());
        }
    }

    #[test]
    fn test_get_relation_attrs() {
        let entity = setup();
        let attrs = entity.get_relation_attrs();

        assert_eq!(attrs.len(), entity.relations.len());
        for (attr, rel) in attrs.iter().zip(&entity.relations) {
            assert_eq!(attr.to_string(), rel.get_attrs().to_string());
        }
    }

    #[test]
    fn test_get_primary_key_auto_increment() {
        let mut entity = setup();

        assert_eq!(
            entity.get_primary_key_auto_increment(),
            format_ident!("{}", false)
        );

        entity.columns[0].auto_increment = true;
        assert_eq!(
            entity.get_primary_key_auto_increment(),
            format_ident!("{}", true)
        );
    }

    #[test]
    fn test_get_primary_key_rs_type() {
        let mut entity = setup();
        let id_type = entity.columns[0].get_rs_type(&DateTimeCrate::Chrono);

        // Single key column: its own type.
        assert_eq!(
            entity
                .get_primary_key_rs_type(&DateTimeCrate::Chrono)
                .to_string(),
            id_type.to_string()
        );

        // Composite key: a tuple of the key columns' types, in key order.
        entity.primary_keys.push(PrimaryKey {
            name: "name".to_owned(),
        });
        let name_type = entity.columns[1].get_rs_type(&DateTimeCrate::Chrono);
        assert_eq!(
            entity
                .get_primary_key_rs_type(&DateTimeCrate::Chrono)
                .to_string(),
            quote! { (#id_type, #name_type) }.to_string()
        );

        // No primary key: an empty token stream.
        entity.primary_keys.clear();
        assert!(entity
            .get_primary_key_rs_type(&DateTimeCrate::Chrono)
            .is_empty());
    }

    #[test]
    fn test_get_conjunct_relations_via_snake_case() {
        let entity = setup();
        let names = entity.get_conjunct_relations_via_snake_case();

        assert_eq!(names.len(), entity.conjunct_relations.len());
        for (name, con_rel) in names.into_iter().zip(&entity.conjunct_relations) {
            assert_eq!(name, con_rel.get_via_snake_case());
        }
    }

    #[test]
    fn test_get_conjunct_relations_to_snake_case() {
        let entity = setup();
        let names = entity.get_conjunct_relations_to_snake_case();

        assert_eq!(names.len(), entity.conjunct_relations.len());
        for (name, con_rel) in names.into_iter().zip(&entity.conjunct_relations) {
            assert_eq!(name, con_rel.get_to_snake_case());
        }
    }

    #[test]
    fn test_get_conjunct_relations_to_upper_camel_case() {
        let entity = setup();
        let names = entity.get_conjunct_relations_to_upper_camel_case();

        assert_eq!(names.len(), entity.conjunct_relations.len());
        for (name, con_rel) in names.into_iter().zip(&entity.conjunct_relations) {
            assert_eq!(name, con_rel.get_to_upper_camel_case());
        }
    }

    #[test]
    fn test_get_eq_needed() {
        let mut entity = setup();

        // Integer and string columns: `Eq` can be derived.
        assert_eq!(entity.get_eq_needed().to_string(), quote! {, Eq}.to_string());

        // A float column cannot implement `Eq`, so nothing is emitted.
        entity.columns[1].col_type = ColumnType::Float;
        assert_eq!(entity.get_eq_needed().to_string(), quote! {}.to_string());
    }
}
548}