1use heck::{ToSnakeCase, ToUpperCamelCase};
2use proc_macro2::{Ident, TokenStream};
3use quote::format_ident;
4use quote::quote;
5use sea_query::ColumnType;
6
7use crate::{
8 Column, ColumnOption, ConjunctRelation, PrimaryKey, Relation, util::escape_rust_keyword,
9};
10
/// Description of a single table from which identifiers and token streams
/// are produced by the accessor methods implemented on this type.
#[derive(Clone, Debug)]
pub struct Entity {
    /// Table name as supplied; converted to snake_case / UpperCamelCase on demand.
    pub(crate) table_name: String,
    /// Columns of the table; names, Rust types and definitions are derived per column.
    pub(crate) columns: Vec<Column>,
    /// Relations of this table; each one yields a module name, enum name,
    /// definition and attributes through the `Relation` helpers.
    pub(crate) relations: Vec<Relation>,
    /// Conjunct relations; each one exposes a "via" name and a "to" name.
    // NOTE(review): presumably many-to-many links through a junction table — confirm.
    pub(crate) conjunct_relations: Vec<ConjunctRelation>,
    /// Primary key entries; they are matched against `columns` by `name`.
    pub(crate) primary_keys: Vec<PrimaryKey>,
}
19
20impl Entity {
21 pub fn get_table_name_snake_case(&self) -> String {
22 self.table_name.to_snake_case()
23 }
24
25 pub fn get_table_name_camel_case(&self) -> String {
26 self.table_name.to_upper_camel_case()
27 }
28
29 pub fn get_table_name_snake_case_ident(&self) -> Ident {
30 format_ident!("{}", escape_rust_keyword(self.get_table_name_snake_case()))
31 }
32
33 pub fn get_table_name_camel_case_ident(&self) -> Ident {
34 format_ident!("{}", escape_rust_keyword(self.get_table_name_camel_case()))
35 }
36
37 pub fn get_column_names_snake_case(&self) -> Vec<Ident> {
38 self.columns
39 .iter()
40 .map(|col| col.get_name_snake_case())
41 .collect()
42 }
43
44 pub fn get_column_names_camel_case(&self) -> Vec<Ident> {
45 self.columns
46 .iter()
47 .map(|col| col.get_name_camel_case())
48 .collect()
49 }
50
51 pub fn get_column_rs_types(&self, opt: &ColumnOption) -> Vec<TokenStream> {
52 self.columns
53 .clone()
54 .into_iter()
55 .map(|col| col.get_rs_type(opt))
56 .collect()
57 }
58
59 pub fn get_column_defs(&self) -> Vec<TokenStream> {
60 self.columns
61 .clone()
62 .into_iter()
63 .map(|col| col.get_def())
64 .collect()
65 }
66
67 pub fn get_primary_key_names_snake_case(&self) -> Vec<Ident> {
68 self.primary_keys
69 .iter()
70 .map(|pk| pk.get_name_snake_case())
71 .collect()
72 }
73
74 pub fn get_primary_key_names_camel_case(&self) -> Vec<Ident> {
75 self.primary_keys
76 .iter()
77 .map(|pk| pk.get_name_camel_case())
78 .collect()
79 }
80
81 pub fn get_relation_module_name(&self) -> Vec<Option<Ident>> {
82 self.relations
83 .iter()
84 .map(|rel| rel.get_module_name())
85 .collect()
86 }
87
88 pub fn get_relation_enum_name(&self) -> Vec<Ident> {
89 self.relations
90 .iter()
91 .map(|rel| rel.get_enum_name())
92 .collect()
93 }
94
95 pub fn get_related_entity_enum_name(&self) -> Vec<Ident> {
97 let conjunct_related_names = self.get_conjunct_relations_to_upper_camel_case();
99
100 let self_relations_reverse = self
102 .relations
103 .iter()
104 .filter(|rel| rel.self_referencing)
105 .map(|rel| format_ident!("{}Reverse", rel.get_enum_name()));
106
107 self.get_relation_enum_name()
109 .into_iter()
110 .chain(self_relations_reverse)
111 .chain(conjunct_related_names)
112 .collect()
113 }
114
115 pub fn get_relation_defs(&self) -> Vec<TokenStream> {
116 self.relations.iter().map(|rel| rel.get_def()).collect()
117 }
118
119 pub fn get_relation_attrs(&self) -> Vec<TokenStream> {
120 self.relations.iter().map(|rel| rel.get_attrs()).collect()
121 }
122
123 pub fn get_related_entity_modules(&self) -> Vec<Ident> {
125 let conjunct_related_attrs = self
127 .conjunct_relations
128 .iter()
129 .map(|conj| conj.get_to_snake_case());
130
131 let produce_relation_attrs = |rel: &Relation, _reverse: bool| match rel.get_module_name() {
133 Some(module_name) => module_name,
134 None => format_ident!("self"),
135 };
136
137 let self_relations_reverse_attrs = self
139 .relations
140 .iter()
141 .filter(|rel| rel.self_referencing)
142 .map(|rel| produce_relation_attrs(rel, true));
143
144 self.relations
146 .iter()
147 .map(|rel| produce_relation_attrs(rel, false))
148 .chain(self_relations_reverse_attrs)
149 .chain(conjunct_related_attrs)
150 .collect()
151 }
152
153 pub fn get_related_entity_attrs(&self) -> Vec<TokenStream> {
155 let conjunct_related_attrs = self.conjunct_relations.iter().map(|conj| {
157 let entity = format!("super::{}::Entity", conj.get_to_snake_case());
158
159 quote! {
160 #[sea_orm(
161 entity = #entity
162 )]
163 }
164 });
165
166 let produce_relation_attrs = |rel: &Relation, reverse: bool| {
168 let entity = match rel.get_module_name() {
169 Some(module_name) => format!("super::{module_name}::Entity"),
170 None => String::from("Entity"),
171 };
172
173 if rel.self_referencing || !rel.impl_related || rel.num_suffix > 0 {
174 let def = if reverse {
175 format!("Relation::{}.def().rev()", rel.get_enum_name())
176 } else {
177 format!("Relation::{}.def()", rel.get_enum_name())
178 };
179
180 quote! {
181 #[sea_orm(
182 entity = #entity,
183 def = #def
184 )]
185 }
186 } else {
187 quote! {
188 #[sea_orm(
189 entity = #entity
190 )]
191 }
192 }
193 };
194
195 let self_relations_reverse_attrs = self
197 .relations
198 .iter()
199 .filter(|rel| rel.self_referencing)
200 .map(|rel| produce_relation_attrs(rel, true));
201
202 self.relations
204 .iter()
205 .map(|rel| produce_relation_attrs(rel, false))
206 .chain(self_relations_reverse_attrs)
207 .chain(conjunct_related_attrs)
208 .collect()
209 }
210
211 pub fn get_primary_key_auto_increment(&self) -> Ident {
212 let auto_increment = self.columns.iter().any(|col| col.auto_increment);
213 format_ident!("{}", auto_increment)
214 }
215
216 pub fn get_primary_key_rs_type(&self, opt: &ColumnOption) -> TokenStream {
217 let types = self
218 .primary_keys
219 .iter()
220 .map(|primary_key| {
221 self.columns
222 .iter()
223 .find(|col| col.name.eq(&primary_key.name))
224 .unwrap()
225 .get_rs_type(opt)
226 .to_string()
227 })
228 .collect::<Vec<_>>();
229 if !types.is_empty() {
230 let value_type = if types.len() > 1 {
231 vec!["(".to_owned(), types.join(", "), ")".to_owned()]
232 } else {
233 types
234 };
235 value_type.join("").parse().unwrap()
236 } else {
237 TokenStream::new()
238 }
239 }
240
241 pub fn get_conjunct_relations_via_snake_case(&self) -> Vec<Ident> {
242 self.conjunct_relations
243 .iter()
244 .map(|con_rel| con_rel.get_via_snake_case())
245 .collect()
246 }
247
248 pub fn get_conjunct_relations_to_snake_case(&self) -> Vec<Ident> {
249 self.conjunct_relations
250 .iter()
251 .map(|con_rel| con_rel.get_to_snake_case())
252 .collect()
253 }
254
255 pub fn get_conjunct_relations_to_upper_camel_case(&self) -> Vec<Ident> {
256 self.conjunct_relations
257 .iter()
258 .map(|con_rel| con_rel.get_to_upper_camel_case())
259 .collect()
260 }
261
262 pub fn get_eq_needed(&self) -> TokenStream {
263 fn is_floats(col_type: &ColumnType) -> bool {
264 match col_type {
265 ColumnType::Float | ColumnType::Double => true,
266 ColumnType::Array(col_type) => is_floats(col_type),
267 ColumnType::Vector(_) => true,
268 _ => false,
269 }
270 }
271 self.columns
272 .iter()
273 .find(|column| is_floats(&column.col_type))
274 .map_or(quote! {, Eq}, |_| quote! {})
277 }
278
279 pub fn get_column_serde_attributes(
280 &self,
281 serde_skip_deserializing_primary_key: bool,
282 serde_skip_hidden_column: bool,
283 ) -> Vec<TokenStream> {
284 self.columns
285 .iter()
286 .map(|col| {
287 let is_primary_key = self.primary_keys.iter().any(|pk| pk.name == col.name);
288 col.get_serde_attribute(
289 is_primary_key,
290 serde_skip_deserializing_primary_key,
291 serde_skip_hidden_column,
292 )
293 })
294 .collect()
295 }
296}
297
#[cfg(test)]
mod tests {
    use quote::{format_ident, quote};
    use sea_query::{ColumnType, ForeignKeyAction, StringLen};

    use crate::{Column, ColumnOption, Entity, PrimaryKey, Relation, RelationType};

    /// Builds the fixture: table `special_cake` with columns `id` and `name`,
    /// two has-one relations (`fruit`, `filling`), no conjunct relations, and
    /// `id` as the only primary key.
    fn setup() -> Entity {
        Entity {
            table_name: "special_cake".to_owned(),
            columns: vec![
                Column {
                    name: "id".to_owned(),
                    col_type: ColumnType::Integer,
                    auto_increment: false,
                    not_null: false,
                    unique: false,
                    unique_key: None,
                },
                Column {
                    name: "name".to_owned(),
                    col_type: ColumnType::String(StringLen::None),
                    auto_increment: false,
                    not_null: false,
                    unique: false,
                    unique_key: None,
                },
            ],
            relations: vec![
                Relation {
                    ref_table: "fruit".to_owned(),
                    columns: vec!["id".to_owned()],
                    ref_columns: vec!["cake_id".to_owned()],
                    rel_type: RelationType::HasOne,
                    on_delete: Some(ForeignKeyAction::Cascade),
                    on_update: Some(ForeignKeyAction::Cascade),
                    self_referencing: false,
                    num_suffix: 0,
                    impl_related: true,
                },
                Relation {
                    ref_table: "filling".to_owned(),
                    columns: vec!["id".to_owned()],
                    ref_columns: vec!["cake_id".to_owned()],
                    rel_type: RelationType::HasOne,
                    on_delete: Some(ForeignKeyAction::Cascade),
                    on_update: Some(ForeignKeyAction::Cascade),
                    self_referencing: false,
                    num_suffix: 0,
                    impl_related: true,
                },
            ],
            conjunct_relations: vec![],
            primary_keys: vec![PrimaryKey {
                name: "id".to_owned(),
            }],
        }
    }

    /// Renders every item to a `String` so token streams can be compared as
    /// whole lists; a length mismatch then fails instead of being skipped.
    fn to_strings<T: ToString>(items: impl IntoIterator<Item = T>) -> Vec<String> {
        items.into_iter().map(|item| item.to_string()).collect()
    }

    #[test]
    fn test_get_table_name_snake_case() {
        let entity = setup();

        assert_eq!(
            entity.get_table_name_snake_case(),
            "special_cake".to_owned()
        );
    }

    #[test]
    fn test_get_table_name_camel_case() {
        let entity = setup();

        assert_eq!(entity.get_table_name_camel_case(), "SpecialCake".to_owned());
    }

    #[test]
    fn test_get_table_name_snake_case_ident() {
        let entity = setup();

        assert_eq!(
            entity.get_table_name_snake_case_ident(),
            format_ident!("{}", "special_cake")
        );
    }

    #[test]
    fn test_get_table_name_camel_case_ident() {
        let entity = setup();

        assert_eq!(
            entity.get_table_name_camel_case_ident(),
            format_ident!("{}", "SpecialCake")
        );
    }

    #[test]
    fn test_get_column_names_snake_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .columns
            .iter()
            .map(|col| col.get_name_snake_case())
            .collect();

        assert_eq!(entity.get_column_names_snake_case(), expected);
    }

    #[test]
    fn test_get_column_names_camel_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .columns
            .iter()
            .map(|col| col.get_name_camel_case())
            .collect();

        assert_eq!(entity.get_column_names_camel_case(), expected);
    }

    #[test]
    fn test_get_column_rs_types() {
        let entity = setup();
        let opt = ColumnOption::default();
        let expected = to_strings(entity.columns.iter().map(|col| col.get_rs_type(&opt)));

        assert_eq!(to_strings(entity.get_column_rs_types(&opt)), expected);
    }

    #[test]
    fn test_get_column_defs() {
        let entity = setup();
        let expected = to_strings(entity.columns.iter().map(|col| col.get_def()));

        assert_eq!(to_strings(entity.get_column_defs()), expected);
    }

    #[test]
    fn test_get_primary_key_names_snake_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .primary_keys
            .iter()
            .map(|pk| pk.get_name_snake_case())
            .collect();

        assert_eq!(entity.get_primary_key_names_snake_case(), expected);
    }

    #[test]
    fn test_get_primary_key_names_camel_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .primary_keys
            .iter()
            .map(|pk| pk.get_name_camel_case())
            .collect();

        assert_eq!(entity.get_primary_key_names_camel_case(), expected);
    }

    #[test]
    fn test_get_relation_module_name() {
        let entity = setup();
        let expected: Vec<_> = entity
            .relations
            .iter()
            .map(|rel| rel.get_module_name())
            .collect();

        assert_eq!(entity.get_relation_module_name(), expected);
    }

    #[test]
    fn test_get_relation_enum_name() {
        let entity = setup();
        let expected: Vec<_> = entity
            .relations
            .iter()
            .map(|rel| rel.get_enum_name())
            .collect();

        assert_eq!(entity.get_relation_enum_name(), expected);
    }

    #[test]
    fn test_get_relation_defs() {
        let entity = setup();
        let expected = to_strings(entity.relations.iter().map(|rel| rel.get_def()));

        assert_eq!(to_strings(entity.get_relation_defs()), expected);
    }

    #[test]
    fn test_get_relation_attrs() {
        let entity = setup();
        let expected = to_strings(entity.relations.iter().map(|rel| rel.get_attrs()));

        assert_eq!(to_strings(entity.get_relation_attrs()), expected);
    }

    #[test]
    fn test_get_primary_key_auto_increment() {
        let mut entity = setup();

        assert_eq!(
            entity.get_primary_key_auto_increment(),
            format_ident!("{}", false)
        );

        entity.columns[0].auto_increment = true;
        assert_eq!(
            entity.get_primary_key_auto_increment(),
            format_ident!("{}", true)
        );
    }

    #[test]
    fn test_get_primary_key_rs_type() {
        let entity = setup();
        let opt = ColumnOption::default();

        assert_eq!(
            entity.get_primary_key_rs_type(&opt).to_string(),
            entity.columns[0].get_rs_type(&opt).to_string()
        );
    }

    #[test]
    fn test_get_primary_key_rs_type_composite() {
        let mut entity = setup();
        let opt = ColumnOption::default();
        entity.primary_keys.push(PrimaryKey {
            name: "name".to_owned(),
        });

        // Several keys are rendered as a tuple of the column types, in key order.
        let first = entity.columns[0].get_rs_type(&opt);
        let second = entity.columns[1].get_rs_type(&opt);
        let expected = quote! { (#first, #second) };

        assert_eq!(
            entity.get_primary_key_rs_type(&opt).to_string(),
            expected.to_string()
        );
    }

    #[test]
    fn test_get_primary_key_rs_type_without_primary_key() {
        let mut entity = setup();
        let opt = ColumnOption::default();
        entity.primary_keys.clear();

        assert!(entity.get_primary_key_rs_type(&opt).is_empty());
    }

    #[test]
    fn test_get_conjunct_relations_via_snake_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .conjunct_relations
            .iter()
            .map(|con_rel| con_rel.get_via_snake_case())
            .collect();

        assert_eq!(entity.get_conjunct_relations_via_snake_case(), expected);
    }

    #[test]
    fn test_get_conjunct_relations_to_snake_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .conjunct_relations
            .iter()
            .map(|con_rel| con_rel.get_to_snake_case())
            .collect();

        assert_eq!(entity.get_conjunct_relations_to_snake_case(), expected);
    }

    #[test]
    fn test_get_conjunct_relations_to_upper_camel_case() {
        let entity = setup();
        let expected: Vec<_> = entity
            .conjunct_relations
            .iter()
            .map(|con_rel| con_rel.get_to_upper_camel_case())
            .collect();

        assert_eq!(
            entity.get_conjunct_relations_to_upper_camel_case(),
            expected
        );
    }

    #[test]
    fn test_get_eq_needed() {
        let mut entity = setup();
        let expected = quote! {, Eq};

        assert_eq!(entity.get_eq_needed().to_string(), expected.to_string());

        // A single float column is enough to drop the `Eq` tokens.
        entity.columns[1].col_type = ColumnType::Double;
        assert!(entity.get_eq_needed().is_empty());
    }
}