1use crate::{
2 ActiveEnum, Column, ConjunctRelation, Entity, EntityWriter, Error, PrimaryKey, Relation,
3 RelationType,
4};
5use sea_query::TableCreateStatement;
6use std::collections::{BTreeMap, HashMap};
7
/// Stateless converter from `sea_query` `CREATE TABLE` statements into the
/// entity models consumed by [`EntityWriter`]; see [`EntityTransformer::transform`].
#[derive(Debug, Clone)]
pub struct EntityTransformer;
10
11impl EntityTransformer {
12 pub fn transform(table_create_stmts: Vec<TableCreateStatement>) -> Result<EntityWriter, Error> {
13 let mut enums: BTreeMap<String, ActiveEnum> = BTreeMap::new();
14 let mut inverse_relations: BTreeMap<String, Vec<Relation>> = BTreeMap::new();
15 let mut entities = BTreeMap::new();
16 for table_create in table_create_stmts.into_iter() {
17 let table_name = match table_create.get_table_name() {
18 Some(table_ref) => table_ref.sea_orm_table().to_string(),
19 None => {
20 return Err(Error::TransformError(
21 "Table name should not be empty".into(),
22 ));
23 }
24 };
25 let mut primary_keys: Vec<PrimaryKey> = Vec::new();
26 let mut columns: Vec<Column> = table_create
27 .get_columns()
28 .iter()
29 .map(|col_def| {
30 let primary_key = col_def.get_column_spec().primary_key;
31 if primary_key {
32 primary_keys.push(PrimaryKey {
33 name: col_def.get_column_name(),
34 });
35 }
36 col_def.into()
37 })
38 .map(|mut col: Column| {
39 col.unique = table_create
40 .get_indexes()
41 .iter()
42 .filter(|index| index.is_unique_key())
43 .map(|index| index.get_index_spec().get_column_names())
44 .filter(|col_names| col_names.len() == 1 && col_names[0] == col.name)
45 .count()
46 > 0;
47 col
48 })
49 .inspect(|col| {
50 if let sea_query::ColumnType::Enum { name, variants } = col.get_inner_col_type()
51 {
52 enums.insert(
53 name.to_string(),
54 ActiveEnum {
55 enum_name: name.clone(),
56 values: variants.clone(),
57 },
58 );
59 }
60 })
61 .collect();
62 for index in table_create.get_indexes().iter() {
63 if index.is_unique_key() {
64 let col_names = index.get_index_spec().get_column_names();
65 if col_names.len() > 1 {
66 if let Some(mut key_name) = index.get_index_spec().get_name() {
67 if let Some((_, suffix)) = key_name.rsplit_once('-') {
68 key_name = suffix;
69 }
70 for col_name in col_names {
71 for column in columns.iter_mut() {
72 if column.name == col_name {
73 column.unique_key = Some(key_name.to_owned());
74 }
75 }
76 }
77 }
78 }
79 }
80 }
81 let mut ref_table_counts: BTreeMap<String, usize> = BTreeMap::new();
82 let relations: Vec<Relation> = table_create
83 .get_foreign_key_create_stmts()
84 .iter()
85 .map(|fk_create_stmt| fk_create_stmt.get_foreign_key())
86 .map(|tbl_fk| {
87 let ref_tbl = tbl_fk.get_ref_table().unwrap().sea_orm_table().to_string();
88 if let Some(count) = ref_table_counts.get_mut(&ref_tbl) {
89 if *count == 0 {
90 *count = 1;
91 }
92 *count += 1;
93 } else {
94 ref_table_counts.insert(ref_tbl, 0);
95 };
96 tbl_fk.into()
97 })
98 .collect::<Vec<_>>()
99 .into_iter()
100 .rev()
101 .map(|mut rel: Relation| {
102 rel.self_referencing = rel.ref_table == table_name;
103 if let Some(count) = ref_table_counts.get_mut(&rel.ref_table) {
104 rel.num_suffix = *count;
105 if *count > 0 {
106 *count -= 1;
107 }
108 }
109 rel
110 })
111 .rev()
112 .collect();
113 primary_keys.extend(
114 table_create
115 .get_indexes()
116 .iter()
117 .filter(|index| index.is_primary_key())
118 .flat_map(|index| {
119 index
120 .get_index_spec()
121 .get_column_names()
122 .into_iter()
123 .map(|name| PrimaryKey { name })
124 .collect::<Vec<_>>()
125 }),
126 );
127 let entity = Entity {
128 table_name: table_name.clone(),
129 columns,
130 relations: relations.clone(),
131 conjunct_relations: vec![],
132 primary_keys,
133 };
134 entities.insert(table_name.clone(), entity.clone());
135 for mut rel in relations.into_iter() {
136 if rel.self_referencing {
138 continue;
139 }
140 if rel.num_suffix > 0 {
143 continue;
144 }
145 let ref_table = rel.ref_table;
146 let mut unique = true;
147 for column in rel.columns.iter() {
148 if !entity
149 .columns
150 .iter()
151 .filter(|col| col.unique)
152 .any(|col| col.name.as_str() == column)
153 {
154 unique = false;
155 break;
156 }
157 }
158 if rel.columns.len() == entity.primary_keys.len() {
159 let mut count_pk = 0;
160 for primary_key in entity.primary_keys.iter() {
161 if rel.columns.contains(&primary_key.name) {
162 count_pk += 1;
163 }
164 }
165 if count_pk == entity.primary_keys.len() {
166 unique = true;
167 }
168 }
169 let rel_type = if unique {
170 RelationType::HasOne
171 } else {
172 RelationType::HasMany
173 };
174 rel.rel_type = rel_type;
175 rel.ref_table = table_name.to_string();
176 rel.columns = Vec::new();
177 rel.ref_columns = Vec::new();
178 if let Some(vec) = inverse_relations.get_mut(&ref_table) {
179 vec.push(rel);
180 } else {
181 inverse_relations.insert(ref_table, vec![rel]);
182 }
183 }
184 }
185 for (tbl_name, relations) in inverse_relations.into_iter() {
186 if let Some(entity) = entities.get_mut(&tbl_name) {
187 for relation in relations.into_iter() {
188 let duplicate_relation = entity
189 .relations
190 .iter()
191 .any(|rel| rel.ref_table == relation.ref_table);
192 if !duplicate_relation {
193 entity.relations.push(relation);
194 }
195 }
196 }
197 }
198 for table_name in entities.clone().keys() {
199 let relations = match entities.get(table_name) {
200 Some(entity) => {
201 let is_conjunct_relation =
202 entity.relations.len() == 2 && entity.primary_keys.len() == 2;
203 if !is_conjunct_relation {
204 continue;
205 }
206 entity.relations.clone()
207 }
208 None => unreachable!(),
209 };
210 for (i, rel) in relations.iter().enumerate() {
211 let another_rel = relations.get((i == 0) as usize).unwrap();
212 if let Some(entity) = entities.get_mut(&rel.ref_table) {
213 let conjunct_relation = ConjunctRelation {
214 via: table_name.clone(),
215 to: another_rel.ref_table.clone(),
216 };
217 entity.conjunct_relations.push(conjunct_relation);
218 }
219 }
220 }
221 Ok(EntityWriter {
222 entities: entities
223 .into_values()
224 .map(|mut v| {
225 let duplicated_to: Vec<_> = v
227 .conjunct_relations
228 .iter()
229 .fold(HashMap::new(), |mut acc, conjunct_relation| {
230 acc.entry(conjunct_relation.to.clone())
231 .and_modify(|c| *c += 1)
232 .or_insert(1);
233 acc
234 })
235 .into_iter()
236 .filter(|(_, v)| v > &1)
237 .map(|(k, _)| k)
238 .collect();
239 v.conjunct_relations
240 .retain(|conjunct_relation| !duplicated_to.contains(&conjunct_relation.to));
241
242 v.relations.iter_mut().for_each(|relation| {
245 if v.conjunct_relations
246 .iter()
247 .any(|conjunct_relation| conjunct_relation.to == relation.ref_table)
248 {
249 relation.impl_related = false;
250 }
251 });
252
253 v.relations.sort_by(|a, b| a.ref_table.cmp(&b.ref_table));
255 v.conjunct_relations.sort_by(|a, b| a.to.cmp(&b.to));
256 v
257 })
258 .collect(),
259 enums,
260 })
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use proc_macro2::TokenStream;
    use sea_orm::{DbBackend, Schema};
    use std::{
        error::Error,
        io::{self, BufRead, BufReader, Read},
    };

    /// Two junction tables link `bills` and `users`, so the duplicated
    /// conjunct paths must be dropped.
    #[test]
    fn duplicated_many_to_many_paths() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::duplicated_many_to_many_paths::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
                schema.create_table_from_entity(users_saved_bills::Entity),
                schema.create_table_from_entity(users_votes::Entity),
            ],
            vec![
                (
                    "bills",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/bills.rs"),
                ),
                (
                    "users",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/users.rs"),
                ),
                (
                    "users_saved_bills",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/users_saved_bills.rs"),
                ),
                (
                    "users_votes",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/users_votes.rs"),
                ),
            ],
        )
    }

    /// A single junction table yields conjunct relations on both sides.
    #[test]
    fn many_to_many() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::many_to_many::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
                schema.create_table_from_entity(users_votes::Entity),
            ],
            vec![
                ("bills", include_str!("../tests_cfg/many_to_many/bills.rs")),
                ("users", include_str!("../tests_cfg/many_to_many/users.rs")),
                (
                    "users_votes",
                    include_str!("../tests_cfg/many_to_many/users_votes.rs"),
                ),
            ],
        )
    }

    /// Several foreign keys to the same table produce numbered relations.
    #[test]
    fn many_to_many_multiple() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::many_to_many_multiple::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
                schema.create_table_from_entity(users_votes::Entity),
            ],
            vec![
                (
                    "bills",
                    include_str!("../tests_cfg/many_to_many_multiple/bills.rs"),
                ),
                (
                    "users",
                    include_str!("../tests_cfg/many_to_many_multiple/users.rs"),
                ),
                (
                    "users_votes",
                    include_str!("../tests_cfg/many_to_many_multiple/users_votes.rs"),
                ),
            ],
        )
    }

    /// Self-referencing foreign keys must not produce an inverse relation.
    #[test]
    fn self_referencing() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::self_referencing::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
            ],
            vec![
                (
                    "bills",
                    include_str!("../tests_cfg/self_referencing/bills.rs"),
                ),
                (
                    "users",
                    include_str!("../tests_cfg/self_referencing/users.rs"),
                ),
            ],
        )
    }

    /// Unique indexes (single- and multi-column) survive the transform in both
    /// compact and dense output formats.
    #[test]
    fn test_indexes_transform() -> Result<(), Box<dyn Error>> {
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_with_index_from_entity(
                    crate::tests_cfg::compact::indexes::Entity,
                ),
            ],
            vec![("indexes", include_str!("../tests_cfg/compact/indexes.rs"))],
        )?;

        validate_dense_entities(
            vec![
                schema
                    .create_table_with_index_from_entity(crate::tests_cfg::dense::indexes::Entity),
            ],
            vec![("indexes", include_str!("../tests_cfg/dense/indexes.rs"))],
        )?;

        Ok(())
    }

    /// Generates a validator that transforms `table_create_stmts`, renders each
    /// listed entity with `EntityWriter::$method` and compares the tokens
    /// (minus the first generated block) against the expected fixture.
    macro_rules! validate_entities_fn {
        ($fn_name: ident, $method: ident) => {
            fn $fn_name(
                table_create_stmts: Vec<TableCreateStatement>,
                files: Vec<(&str, &str)>,
            ) -> Result<(), Box<dyn Error>> {
                let entities: HashMap<_, _> = EntityTransformer::transform(table_create_stmts)?
                    .entities
                    .into_iter()
                    .map(|entity| (entity.table_name.clone(), entity))
                    .collect();

                for (entity_name, file_content) in files {
                    let entity = entities.get(entity_name).unwrap_or_else(|| {
                        panic!(
                            "entity `{}` is missing from the transformer output; \
                             did you forget to add its table to the list?",
                            entity_name
                        )
                    });

                    let expected = parse_from_file(file_content.as_bytes())?.to_string();
                    let generated: TokenStream = EntityWriter::$method(
                        entity,
                        &crate::WithSerde::None,
                        &Default::default(),
                        &None,
                        false,
                        false,
                        &Default::default(),
                        &Default::default(),
                        &Default::default(),
                        false,
                        true,
                    )
                    .into_iter()
                    // The first generated block is not part of the fixtures'
                    // compared content (the fixtures' text up to the first `;` is
                    // likewise skipped by `parse_from_file`).
                    .skip(1)
                    .collect();

                    assert_eq!(expected, generated.to_string());
                }

                Ok(())
            }
        };
    }

    validate_entities_fn!(validate_compact_entities, gen_compact_code_blocks);
    validate_entities_fn!(validate_dense_entities, gen_dense_code_blocks);

    /// Parses an expected-output fixture into a `TokenStream`.
    ///
    /// Everything up to and including the first `;` is discarded; the rest is
    /// lexed as Rust tokens.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if reading fails or the content is not valid UTF-8,
    /// and an `InvalidData` error if the remaining text cannot be tokenized.
    fn parse_from_file<R>(inner: R) -> io::Result<TokenStream>
    where
        R: io::Read,
    {
        let mut reader = BufReader::new(inner);

        reader.read_until(b';', &mut Vec::new())?;

        let mut content = String::new();
        reader.read_to_string(&mut content)?;
        content
            .parse::<TokenStream>()
            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err.to_string()))
    }
}