1use heck::{ToSnakeCase, ToUpperCamelCase};
2use proc_macro2::{Ident, TokenStream};
3use quote::format_ident;
4use quote::quote;
5use sea_query::ColumnType;
6
7use crate::{
8 Column, ConjunctRelation, DateTimeCrate, PrimaryKey, Relation, util::escape_rust_keyword,
9};
10
/// Description of a single database table from which the entity code
/// generator builds its token streams (columns, relations, primary key type,
/// `#[sea_orm(...)]` attributes, ...).
#[derive(Clone, Debug)]
pub struct Entity {
    /// Raw table name; converted to `snake_case` / `UpperCamelCase` by the
    /// `get_table_name_*` helpers.
    pub(crate) table_name: String,
    /// Column definitions. The `get_column_*` helpers iterate them in this
    /// order, so it determines the order of the generated output.
    pub(crate) columns: Vec<Column>,
    /// Direct relations of this table; each one yields a relation enum name,
    /// definition and attributes.
    pub(crate) relations: Vec<Relation>,
    /// Indirect relations, each exposing a `via` and a `to` table name
    /// (see `get_via_snake_case` / `get_to_snake_case`). They are appended after
    /// the direct relations when listing related entities.
    pub(crate) conjunct_relations: Vec<ConjunctRelation>,
    /// Primary-key columns. Each `name` is expected to match an entry in
    /// `columns`; `get_primary_key_rs_type` panics when it does not.
    pub(crate) primary_keys: Vec<PrimaryKey>,
}
19
20impl Entity {
21 pub fn get_table_name_snake_case(&self) -> String {
22 self.table_name.to_snake_case()
23 }
24
25 pub fn get_table_name_camel_case(&self) -> String {
26 self.table_name.to_upper_camel_case()
27 }
28
29 pub fn get_table_name_snake_case_ident(&self) -> Ident {
30 format_ident!("{}", escape_rust_keyword(self.get_table_name_snake_case()))
31 }
32
33 pub fn get_table_name_camel_case_ident(&self) -> Ident {
34 format_ident!("{}", escape_rust_keyword(self.get_table_name_camel_case()))
35 }
36
37 pub fn get_column_names_snake_case(&self) -> Vec<Ident> {
38 self.columns
39 .iter()
40 .map(|col| col.get_name_snake_case())
41 .collect()
42 }
43
44 pub fn get_column_names_camel_case(&self) -> Vec<Ident> {
45 self.columns
46 .iter()
47 .map(|col| col.get_name_camel_case())
48 .collect()
49 }
50
51 pub fn get_column_rs_types(&self, date_time_crate: &DateTimeCrate) -> Vec<TokenStream> {
52 self.columns
53 .clone()
54 .into_iter()
55 .map(|col| col.get_rs_type(date_time_crate))
56 .collect()
57 }
58
59 pub fn get_column_defs(&self) -> Vec<TokenStream> {
60 self.columns
61 .clone()
62 .into_iter()
63 .map(|col| col.get_def())
64 .collect()
65 }
66
67 pub fn get_primary_key_names_snake_case(&self) -> Vec<Ident> {
68 self.primary_keys
69 .iter()
70 .map(|pk| pk.get_name_snake_case())
71 .collect()
72 }
73
74 pub fn get_primary_key_names_camel_case(&self) -> Vec<Ident> {
75 self.primary_keys
76 .iter()
77 .map(|pk| pk.get_name_camel_case())
78 .collect()
79 }
80
81 pub fn get_relation_module_name(&self) -> Vec<Option<Ident>> {
82 self.relations
83 .iter()
84 .map(|rel| rel.get_module_name())
85 .collect()
86 }
87
88 pub fn get_relation_enum_name(&self) -> Vec<Ident> {
89 self.relations
90 .iter()
91 .map(|rel| rel.get_enum_name())
92 .collect()
93 }
94
95 pub fn get_related_entity_enum_name(&self) -> Vec<Ident> {
97 let conjunct_related_names = self.get_conjunct_relations_to_upper_camel_case();
99
100 let self_relations_reverse = self
102 .relations
103 .iter()
104 .filter(|rel| rel.self_referencing)
105 .map(|rel| format_ident!("{}Reverse", rel.get_enum_name()));
106
107 self.get_relation_enum_name()
109 .into_iter()
110 .chain(self_relations_reverse)
111 .chain(conjunct_related_names)
112 .collect()
113 }
114
115 pub fn get_relation_defs(&self) -> Vec<TokenStream> {
116 self.relations.iter().map(|rel| rel.get_def()).collect()
117 }
118
119 pub fn get_relation_attrs(&self) -> Vec<TokenStream> {
120 self.relations.iter().map(|rel| rel.get_attrs()).collect()
121 }
122
123 pub fn get_related_entity_modules(&self) -> Vec<Ident> {
125 let conjunct_related_attrs = self
127 .conjunct_relations
128 .iter()
129 .map(|conj| conj.get_to_snake_case());
130
131 let produce_relation_attrs = |rel: &Relation, _reverse: bool| match rel.get_module_name() {
133 Some(module_name) => module_name,
134 None => format_ident!("self"),
135 };
136
137 let self_relations_reverse_attrs = self
139 .relations
140 .iter()
141 .filter(|rel| rel.self_referencing)
142 .map(|rel| produce_relation_attrs(rel, true));
143
144 self.relations
146 .iter()
147 .map(|rel| produce_relation_attrs(rel, false))
148 .chain(self_relations_reverse_attrs)
149 .chain(conjunct_related_attrs)
150 .collect()
151 }
152
153 pub fn get_related_entity_attrs(&self) -> Vec<TokenStream> {
155 let conjunct_related_attrs = self.conjunct_relations.iter().map(|conj| {
157 let entity = format!("super::{}::Entity", conj.get_to_snake_case());
158
159 quote! {
160 #[sea_orm(
161 entity = #entity
162 )]
163 }
164 });
165
166 let produce_relation_attrs = |rel: &Relation, reverse: bool| {
168 let entity = match rel.get_module_name() {
169 Some(module_name) => format!("super::{module_name}::Entity"),
170 None => String::from("Entity"),
171 };
172
173 if rel.self_referencing || !rel.impl_related || rel.num_suffix > 0 {
174 let def = if reverse {
175 format!("Relation::{}.def().rev()", rel.get_enum_name())
176 } else {
177 format!("Relation::{}.def()", rel.get_enum_name())
178 };
179
180 quote! {
181 #[sea_orm(
182 entity = #entity,
183 def = #def
184 )]
185 }
186 } else {
187 quote! {
188 #[sea_orm(
189 entity = #entity
190 )]
191 }
192 }
193 };
194
195 let self_relations_reverse_attrs = self
197 .relations
198 .iter()
199 .filter(|rel| rel.self_referencing)
200 .map(|rel| produce_relation_attrs(rel, true));
201
202 self.relations
204 .iter()
205 .map(|rel| produce_relation_attrs(rel, false))
206 .chain(self_relations_reverse_attrs)
207 .chain(conjunct_related_attrs)
208 .collect()
209 }
210
211 pub fn get_primary_key_auto_increment(&self) -> Ident {
212 let auto_increment = self.columns.iter().any(|col| col.auto_increment);
213 format_ident!("{}", auto_increment)
214 }
215
216 pub fn get_primary_key_rs_type(&self, date_time_crate: &DateTimeCrate) -> TokenStream {
217 let types = self
218 .primary_keys
219 .iter()
220 .map(|primary_key| {
221 self.columns
222 .iter()
223 .find(|col| col.name.eq(&primary_key.name))
224 .unwrap()
225 .get_rs_type(date_time_crate)
226 .to_string()
227 })
228 .collect::<Vec<_>>();
229 if !types.is_empty() {
230 let value_type = if types.len() > 1 {
231 vec!["(".to_owned(), types.join(", "), ")".to_owned()]
232 } else {
233 types
234 };
235 value_type.join("").parse().unwrap()
236 } else {
237 TokenStream::new()
238 }
239 }
240
241 pub fn get_conjunct_relations_via_snake_case(&self) -> Vec<Ident> {
242 self.conjunct_relations
243 .iter()
244 .map(|con_rel| con_rel.get_via_snake_case())
245 .collect()
246 }
247
248 pub fn get_conjunct_relations_to_snake_case(&self) -> Vec<Ident> {
249 self.conjunct_relations
250 .iter()
251 .map(|con_rel| con_rel.get_to_snake_case())
252 .collect()
253 }
254
255 pub fn get_conjunct_relations_to_upper_camel_case(&self) -> Vec<Ident> {
256 self.conjunct_relations
257 .iter()
258 .map(|con_rel| con_rel.get_to_upper_camel_case())
259 .collect()
260 }
261
262 pub fn get_eq_needed(&self) -> TokenStream {
263 fn is_floats(col_type: &ColumnType) -> bool {
264 match col_type {
265 ColumnType::Float | ColumnType::Double => true,
266 ColumnType::Array(col_type) => is_floats(col_type),
267 ColumnType::Vector(_) => true,
268 _ => false,
269 }
270 }
271 self.columns
272 .iter()
273 .find(|column| is_floats(&column.col_type))
274 .map_or(quote! {, Eq}, |_| quote! {})
277 }
278
279 pub fn get_column_serde_attributes(
280 &self,
281 serde_skip_deserializing_primary_key: bool,
282 serde_skip_hidden_column: bool,
283 ) -> Vec<TokenStream> {
284 self.columns
285 .iter()
286 .map(|col| {
287 let is_primary_key = self.primary_keys.iter().any(|pk| pk.name == col.name);
288 col.get_serde_attribute(
289 is_primary_key,
290 serde_skip_deserializing_primary_key,
291 serde_skip_hidden_column,
292 )
293 })
294 .collect()
295 }
296}
297
#[cfg(test)]
mod tests {
    use quote::{format_ident, quote};
    use sea_query::{ColumnType, ForeignKeyAction, StringLen};

    use crate::{Column, DateTimeCrate, Entity, PrimaryKey, Relation, RelationType};

    /// Builds a `special_cake` entity with two nullable columns (`id`, `name`),
    /// two `HasOne` relations (`fruit`, `filling`), no conjunct relations and a
    /// single-column primary key on `id`.
    fn setup() -> Entity {
        Entity {
            table_name: "special_cake".to_owned(),
            columns: vec![
                Column {
                    name: "id".to_owned(),
                    col_type: ColumnType::Integer,
                    auto_increment: false,
                    not_null: false,
                    unique: false,
                },
                Column {
                    name: "name".to_owned(),
                    col_type: ColumnType::String(StringLen::None),
                    auto_increment: false,
                    not_null: false,
                    unique: false,
                },
            ],
            relations: vec![
                Relation {
                    ref_table: "fruit".to_owned(),
                    columns: vec!["id".to_owned()],
                    ref_columns: vec!["cake_id".to_owned()],
                    rel_type: RelationType::HasOne,
                    on_delete: Some(ForeignKeyAction::Cascade),
                    on_update: Some(ForeignKeyAction::Cascade),
                    self_referencing: false,
                    num_suffix: 0,
                    impl_related: true,
                },
                Relation {
                    ref_table: "filling".to_owned(),
                    columns: vec!["id".to_owned()],
                    ref_columns: vec!["cake_id".to_owned()],
                    rel_type: RelationType::HasOne,
                    on_delete: Some(ForeignKeyAction::Cascade),
                    on_update: Some(ForeignKeyAction::Cascade),
                    self_referencing: false,
                    num_suffix: 0,
                    impl_related: true,
                },
            ],
            conjunct_relations: vec![],
            primary_keys: vec![PrimaryKey {
                name: "id".to_owned(),
            }],
        }
    }

    #[test]
    fn test_get_table_name_snake_case() {
        let entity = setup();

        assert_eq!(
            entity.get_table_name_snake_case(),
            "special_cake".to_owned()
        );
    }

    #[test]
    fn test_get_table_name_camel_case() {
        let entity = setup();

        assert_eq!(entity.get_table_name_camel_case(), "SpecialCake".to_owned());
    }

    #[test]
    fn test_get_table_name_snake_case_ident() {
        let entity = setup();

        assert_eq!(
            entity.get_table_name_snake_case_ident(),
            format_ident!("{}", "special_cake")
        );
    }

    #[test]
    fn test_get_table_name_camel_case_ident() {
        let entity = setup();

        assert_eq!(
            entity.get_table_name_camel_case_ident(),
            format_ident!("{}", "SpecialCake")
        );
    }

    // The tests below compare whole vectors (not element-by-element loops) so
    // that a getter returning too few elements also fails.

    #[test]
    fn test_get_column_names_snake_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .columns
            .iter()
            .map(|col| col.get_name_snake_case())
            .collect();

        assert_eq!(entity.get_column_names_snake_case(), expected);
    }

    #[test]
    fn test_get_column_names_camel_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .columns
            .iter()
            .map(|col| col.get_name_camel_case())
            .collect();

        assert_eq!(entity.get_column_names_camel_case(), expected);
    }

    #[test]
    fn test_get_column_rs_types() {
        let entity = setup();
        let actual: Vec<String> = entity
            .get_column_rs_types(&DateTimeCrate::Chrono)
            .iter()
            .map(|tokens| tokens.to_string())
            .collect();
        let expected: Vec<String> = entity
            .columns
            .iter()
            .map(|col| col.get_rs_type(&DateTimeCrate::Chrono).to_string())
            .collect();

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_get_column_defs() {
        let entity = setup();
        let actual: Vec<String> = entity
            .get_column_defs()
            .iter()
            .map(|tokens| tokens.to_string())
            .collect();
        let expected: Vec<String> = entity
            .columns
            .iter()
            .map(|col| col.get_def().to_string())
            .collect();

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_get_primary_key_names_snake_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .primary_keys
            .iter()
            .map(|pk| pk.get_name_snake_case())
            .collect();

        assert_eq!(entity.get_primary_key_names_snake_case(), expected);
    }

    #[test]
    fn test_get_primary_key_names_camel_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .primary_keys
            .iter()
            .map(|pk| pk.get_name_camel_case())
            .collect();

        assert_eq!(entity.get_primary_key_names_camel_case(), expected);
    }

    #[test]
    fn test_get_relation_module_name() {
        let entity = setup();
        let expected: Vec<_> = entity
            .relations
            .iter()
            .map(|rel| rel.get_module_name())
            .collect();

        assert_eq!(entity.get_relation_module_name(), expected);
    }

    #[test]
    fn test_get_relation_enum_name() {
        let entity = setup();
        let expected: Vec<_> = entity
            .relations
            .iter()
            .map(|rel| rel.get_enum_name())
            .collect();

        assert_eq!(entity.get_relation_enum_name(), expected);
    }

    #[test]
    fn test_get_relation_defs() {
        let entity = setup();
        let actual: Vec<String> = entity
            .get_relation_defs()
            .iter()
            .map(|tokens| tokens.to_string())
            .collect();
        let expected: Vec<String> = entity
            .relations
            .iter()
            .map(|rel| rel.get_def().to_string())
            .collect();

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_get_relation_attrs() {
        let entity = setup();
        let actual: Vec<String> = entity
            .get_relation_attrs()
            .iter()
            .map(|tokens| tokens.to_string())
            .collect();
        let expected: Vec<String> = entity
            .relations
            .iter()
            .map(|rel| rel.get_attrs().to_string())
            .collect();

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_get_related_entity_enum_name() {
        let entity = setup();

        // No self-referencing or conjunct relations: only the direct ones.
        assert_eq!(
            entity.get_related_entity_enum_name(),
            entity.get_relation_enum_name()
        );
    }

    #[test]
    fn test_get_related_entity_with_self_reference() {
        let mut entity = setup();
        entity.relations[0].self_referencing = true;

        let mut expected = entity.get_relation_enum_name();
        expected.push(format_ident!("{}Reverse", entity.relations[0].get_enum_name()));

        assert_eq!(entity.get_related_entity_enum_name(), expected);
        // Modules and attributes must stay parallel to the enum names.
        assert_eq!(entity.get_related_entity_modules().len(), expected.len());
        assert_eq!(entity.get_related_entity_attrs().len(), expected.len());
    }

    #[test]
    fn test_get_related_entity_modules() {
        let entity = setup();
        let expected: Vec<_> = entity
            .relations
            .iter()
            .map(|rel| {
                rel.get_module_name()
                    .unwrap_or_else(|| format_ident!("self"))
            })
            .collect();

        assert_eq!(entity.get_related_entity_modules(), expected);
    }

    #[test]
    fn test_get_primary_key_auto_increment() {
        let mut entity = setup();

        assert_eq!(
            entity.get_primary_key_auto_increment(),
            format_ident!("{}", false)
        );

        entity.columns[0].auto_increment = true;
        assert_eq!(
            entity.get_primary_key_auto_increment(),
            format_ident!("{}", true)
        );
    }

    #[test]
    fn test_get_primary_key_rs_type() {
        let entity = setup();

        assert_eq!(
            entity
                .get_primary_key_rs_type(&DateTimeCrate::Chrono)
                .to_string(),
            entity.columns[0]
                .get_rs_type(&DateTimeCrate::Chrono)
                .to_string()
        );
    }

    #[test]
    fn test_get_primary_key_rs_type_without_primary_key() {
        let mut entity = setup();
        entity.primary_keys.clear();

        assert_eq!(
            entity
                .get_primary_key_rs_type(&DateTimeCrate::Chrono)
                .to_string(),
            ""
        );
    }

    #[test]
    fn test_get_conjunct_relations_via_snake_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .conjunct_relations
            .iter()
            .map(|conj| conj.get_via_snake_case())
            .collect();

        assert_eq!(entity.get_conjunct_relations_via_snake_case(), expected);
    }

    #[test]
    fn test_get_conjunct_relations_to_snake_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .conjunct_relations
            .iter()
            .map(|conj| conj.get_to_snake_case())
            .collect();

        assert_eq!(entity.get_conjunct_relations_to_snake_case(), expected);
    }

    #[test]
    fn test_get_conjunct_relations_to_upper_camel_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .conjunct_relations
            .iter()
            .map(|conj| conj.get_to_upper_camel_case())
            .collect();

        assert_eq!(entity.get_conjunct_relations_to_upper_camel_case(), expected);
    }

    #[test]
    fn test_get_eq_needed() {
        let entity = setup();
        let expected = quote! {, Eq};

        assert_eq!(entity.get_eq_needed().to_string(), expected.to_string());
    }

    #[test]
    fn test_get_eq_needed_with_float_columns() {
        for col_type in [ColumnType::Float, ColumnType::Double] {
            let mut entity = setup();
            entity.columns[1].col_type = col_type;

            assert_eq!(entity.get_eq_needed().to_string(), quote! {}.to_string());
        }
    }
}
578}