1use crate::{
2 ActiveEnum, Column, ConjunctRelation, Entity, EntityWriter, Error, PrimaryKey, Relation,
3 RelationType,
4};
5use sea_query::TableCreateStatement;
6use std::collections::{BTreeMap, HashMap};
7
/// Field-less marker type that namespaces the schema-to-entity transformation.
///
/// It carries no state; the logic is exposed through the associated function
/// [`EntityTransformer::transform`].
#[derive(Clone, Debug)]
pub struct EntityTransformer;
10
11impl EntityTransformer {
12 pub fn transform(table_create_stmts: Vec<TableCreateStatement>) -> Result<EntityWriter, Error> {
13 let mut enums: BTreeMap<String, ActiveEnum> = BTreeMap::new();
14 let mut inverse_relations: BTreeMap<String, Vec<Relation>> = BTreeMap::new();
15 let mut entities = BTreeMap::new();
16 for table_create in table_create_stmts.into_iter() {
17 let table_name = match table_create.get_table_name() {
18 Some(table_ref) => table_ref.sea_orm_table().to_string(),
19 None => {
20 return Err(Error::TransformError(
21 "Table name should not be empty".into(),
22 ));
23 }
24 };
25 let mut primary_keys: Vec<PrimaryKey> = Vec::new();
26 let columns: Vec<Column> = table_create
27 .get_columns()
28 .iter()
29 .map(|col_def| {
30 let primary_key = col_def.get_column_spec().primary_key;
31 if primary_key {
32 primary_keys.push(PrimaryKey {
33 name: col_def.get_column_name(),
34 });
35 }
36 col_def.into()
37 })
38 .map(|mut col: Column| {
39 col.unique = table_create
40 .get_indexes()
41 .iter()
42 .filter(|index| index.is_unique_key())
43 .map(|index| index.get_index_spec().get_column_names())
44 .filter(|col_names| col_names.len() == 1 && col_names[0] == col.name)
45 .count()
46 > 0;
47 col
48 })
49 .inspect(|col| {
50 if let sea_query::ColumnType::Enum { name, variants } = col.get_inner_col_type()
51 {
52 enums.insert(
53 name.to_string(),
54 ActiveEnum {
55 enum_name: name.clone(),
56 values: variants.clone(),
57 },
58 );
59 }
60 })
61 .collect();
62 let mut ref_table_counts: BTreeMap<String, usize> = BTreeMap::new();
63 let relations: Vec<Relation> = table_create
64 .get_foreign_key_create_stmts()
65 .iter()
66 .map(|fk_create_stmt| fk_create_stmt.get_foreign_key())
67 .map(|tbl_fk| {
68 let ref_tbl = tbl_fk.get_ref_table().unwrap().sea_orm_table().to_string();
69 if let Some(count) = ref_table_counts.get_mut(&ref_tbl) {
70 if *count == 0 {
71 *count = 1;
72 }
73 *count += 1;
74 } else {
75 ref_table_counts.insert(ref_tbl, 0);
76 };
77 tbl_fk.into()
78 })
79 .collect::<Vec<_>>()
80 .into_iter()
81 .rev()
82 .map(|mut rel: Relation| {
83 rel.self_referencing = rel.ref_table == table_name;
84 if let Some(count) = ref_table_counts.get_mut(&rel.ref_table) {
85 rel.num_suffix = *count;
86 if *count > 0 {
87 *count -= 1;
88 }
89 }
90 rel
91 })
92 .rev()
93 .collect();
94 primary_keys.extend(
95 table_create
96 .get_indexes()
97 .iter()
98 .filter(|index| index.is_primary_key())
99 .flat_map(|index| {
100 index
101 .get_index_spec()
102 .get_column_names()
103 .into_iter()
104 .map(|name| PrimaryKey { name })
105 .collect::<Vec<_>>()
106 }),
107 );
108 let entity = Entity {
109 table_name: table_name.clone(),
110 columns,
111 relations: relations.clone(),
112 conjunct_relations: vec![],
113 primary_keys,
114 };
115 entities.insert(table_name.clone(), entity.clone());
116 for mut rel in relations.into_iter() {
117 if rel.self_referencing {
119 continue;
120 }
121 if rel.num_suffix > 0 {
124 continue;
125 }
126 let ref_table = rel.ref_table;
127 let mut unique = true;
128 for column in rel.columns.iter() {
129 if !entity
130 .columns
131 .iter()
132 .filter(|col| col.unique)
133 .any(|col| col.name.as_str() == column)
134 {
135 unique = false;
136 break;
137 }
138 }
139 if rel.columns.len() == entity.primary_keys.len() {
140 let mut count_pk = 0;
141 for primary_key in entity.primary_keys.iter() {
142 if rel.columns.contains(&primary_key.name) {
143 count_pk += 1;
144 }
145 }
146 if count_pk == entity.primary_keys.len() {
147 unique = true;
148 }
149 }
150 let rel_type = if unique {
151 RelationType::HasOne
152 } else {
153 RelationType::HasMany
154 };
155 rel.rel_type = rel_type;
156 rel.ref_table = table_name.to_string();
157 rel.columns = Vec::new();
158 rel.ref_columns = Vec::new();
159 if let Some(vec) = inverse_relations.get_mut(&ref_table) {
160 vec.push(rel);
161 } else {
162 inverse_relations.insert(ref_table, vec![rel]);
163 }
164 }
165 }
166 for (tbl_name, relations) in inverse_relations.into_iter() {
167 if let Some(entity) = entities.get_mut(&tbl_name) {
168 for relation in relations.into_iter() {
169 let duplicate_relation = entity
170 .relations
171 .iter()
172 .any(|rel| rel.ref_table == relation.ref_table);
173 if !duplicate_relation {
174 entity.relations.push(relation);
175 }
176 }
177 }
178 }
179 for table_name in entities.clone().keys() {
180 let relations = match entities.get(table_name) {
181 Some(entity) => {
182 let is_conjunct_relation =
183 entity.relations.len() == 2 && entity.primary_keys.len() == 2;
184 if !is_conjunct_relation {
185 continue;
186 }
187 entity.relations.clone()
188 }
189 None => unreachable!(),
190 };
191 for (i, rel) in relations.iter().enumerate() {
192 let another_rel = relations.get((i == 0) as usize).unwrap();
193 if let Some(entity) = entities.get_mut(&rel.ref_table) {
194 let conjunct_relation = ConjunctRelation {
195 via: table_name.clone(),
196 to: another_rel.ref_table.clone(),
197 };
198 entity.conjunct_relations.push(conjunct_relation);
199 }
200 }
201 }
202 Ok(EntityWriter {
203 entities: entities
204 .into_values()
205 .map(|mut v| {
206 let duplicated_to: Vec<_> = v
208 .conjunct_relations
209 .iter()
210 .fold(HashMap::new(), |mut acc, conjunct_relation| {
211 acc.entry(conjunct_relation.to.clone())
212 .and_modify(|c| *c += 1)
213 .or_insert(1);
214 acc
215 })
216 .into_iter()
217 .filter(|(_, v)| v > &1)
218 .map(|(k, _)| k)
219 .collect();
220 v.conjunct_relations
221 .retain(|conjunct_relation| !duplicated_to.contains(&conjunct_relation.to));
222
223 v.relations.iter_mut().for_each(|relation| {
226 if v.conjunct_relations
227 .iter()
228 .any(|conjunct_relation| conjunct_relation.to == relation.ref_table)
229 {
230 relation.impl_related = false;
231 }
232 });
233
234 v.relations.sort_by(|a, b| a.ref_table.cmp(&b.ref_table));
236 v.conjunct_relations.sort_by(|a, b| a.to.cmp(&b.to));
237 v
238 })
239 .collect(),
240 enums,
241 })
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use proc_macro2::TokenStream;
    use sea_orm::{DbBackend, Schema};
    use std::{
        error::Error,
        io::{self, BufRead, BufReader, Read},
    };

    /// Compares generated compact entities with the
    /// `duplicated_many_to_many_paths` fixtures.
    #[test]
    fn duplicated_many_to_many_paths() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::duplicated_many_to_many_paths::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
                schema.create_table_from_entity(users_saved_bills::Entity),
                schema.create_table_from_entity(users_votes::Entity),
            ],
            vec![
                (
                    "bills",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/bills.rs"),
                ),
                (
                    "users",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/users.rs"),
                ),
                (
                    "users_saved_bills",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/users_saved_bills.rs"),
                ),
                (
                    "users_votes",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/users_votes.rs"),
                ),
            ],
        )
    }

    /// Compares generated compact entities with the `many_to_many` fixtures.
    #[test]
    fn many_to_many() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::many_to_many::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
                schema.create_table_from_entity(users_votes::Entity),
            ],
            vec![
                ("bills", include_str!("../tests_cfg/many_to_many/bills.rs")),
                ("users", include_str!("../tests_cfg/many_to_many/users.rs")),
                (
                    "users_votes",
                    include_str!("../tests_cfg/many_to_many/users_votes.rs"),
                ),
            ],
        )
    }

    /// Compares generated compact entities with the `many_to_many_multiple`
    /// fixtures.
    #[test]
    fn many_to_many_multiple() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::many_to_many_multiple::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
                schema.create_table_from_entity(users_votes::Entity),
            ],
            vec![
                (
                    "bills",
                    include_str!("../tests_cfg/many_to_many_multiple/bills.rs"),
                ),
                (
                    "users",
                    include_str!("../tests_cfg/many_to_many_multiple/users.rs"),
                ),
                (
                    "users_votes",
                    include_str!("../tests_cfg/many_to_many_multiple/users_votes.rs"),
                ),
            ],
        )
    }

    /// Compares generated compact entities with the `self_referencing`
    /// fixtures.
    #[test]
    fn self_referencing() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::self_referencing::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
            ],
            vec![
                (
                    "bills",
                    include_str!("../tests_cfg/self_referencing/bills.rs"),
                ),
                (
                    "users",
                    include_str!("../tests_cfg/self_referencing/users.rs"),
                ),
            ],
        )
    }

    /// Transforms `table_create_stmts` and checks the compact code generated
    /// for each listed entity against its expected source text.
    ///
    /// `files` holds `(table name, expected file contents)` pairs. For each
    /// pair the generated code blocks, minus the first one, are concatenated
    /// into a single token stream and its string form is compared with the
    /// fixture as parsed by `parse_from_file`.
    ///
    /// # Panics
    ///
    /// Panics when a listed table name has no transformed entity, or when the
    /// generated and expected token streams differ.
    fn validate_compact_entities(
        table_create_stmts: Vec<TableCreateStatement>,
        files: Vec<(&str, &str)>,
    ) -> Result<(), Box<dyn Error>> {
        let entities: HashMap<_, _> = EntityTransformer::transform(table_create_stmts)?
            .entities
            .into_iter()
            .map(|entity| (entity.table_name.clone(), entity))
            .collect();

        for (entity_name, file_content) in files {
            let entity = entities
                .get(entity_name)
                .expect("Forget to add entity to the list");

            assert_eq!(
                parse_from_file(file_content.as_bytes())?.to_string(),
                EntityWriter::gen_compact_code_blocks(
                    entity,
                    &crate::WithSerde::None,
                    &crate::DateTimeCrate::Chrono,
                    &None,
                    false,
                    false,
                    &Default::default(),
                    &Default::default(),
                    &Default::default(),
                    false,
                    true,
                )
                .into_iter()
                // The first generated block has no counterpart in the parsed
                // fixture, which drops everything through its first `;`.
                .skip(1)
                .fold(TokenStream::new(), |mut acc, tok| {
                    acc.extend(tok);
                    acc
                })
                .to_string()
            );
        }

        Ok(())
    }

    /// Parses the contents of a fixture file into a token stream, discarding
    /// everything up to and including the first `;`.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while reading, including invalid UTF-8 in
    /// the part that is kept.
    ///
    /// # Panics
    ///
    /// Panics when the remaining text is not a valid token stream.
    fn parse_from_file<R>(inner: R) -> io::Result<TokenStream>
    where
        R: io::Read,
    {
        let mut reader = BufReader::new(inner);

        // Skip the leading statement; the bytes read here are thrown away.
        reader.read_until(b';', &mut Vec::new())?;

        // Read the rest in one go instead of allocating a `String` per line.
        let mut content = String::new();
        reader.read_to_string(&mut content)?;

        Ok(content
            .parse()
            .expect("fixture should contain valid Rust tokens"))
    }
}