1use crate::{
2 ActiveEnum, Column, ConjunctRelation, Entity, EntityWriter, Error, PrimaryKey, Relation,
3 RelationType,
4};
5use sea_query::{ColumnSpec, TableCreateStatement};
6use std::collections::{BTreeMap, HashMap};
7
/// Stateless converter from `sea_query` `CREATE TABLE` statements to entity descriptions.
///
/// All work happens in the associated `transform` function; the type carries no data.
#[derive(Debug, Clone)]
pub struct EntityTransformer;
10
11impl EntityTransformer {
12 pub fn transform(table_create_stmts: Vec<TableCreateStatement>) -> Result<EntityWriter, Error> {
13 let mut enums: BTreeMap<String, ActiveEnum> = BTreeMap::new();
14 let mut inverse_relations: BTreeMap<String, Vec<Relation>> = BTreeMap::new();
15 let mut entities = BTreeMap::new();
16 for table_create in table_create_stmts.into_iter() {
17 let table_name = match table_create.get_table_name() {
18 Some(table_ref) => table_ref.sea_orm_table().to_string(),
19 None => {
20 return Err(Error::TransformError(
21 "Table name should not be empty".into(),
22 ));
23 }
24 };
25 let mut primary_keys: Vec<PrimaryKey> = Vec::new();
26 let columns: Vec<Column> = table_create
27 .get_columns()
28 .iter()
29 .map(|col_def| {
30 let primary_key = col_def
31 .get_column_spec()
32 .iter()
33 .any(|spec| matches!(spec, ColumnSpec::PrimaryKey));
34 if primary_key {
35 primary_keys.push(PrimaryKey {
36 name: col_def.get_column_name(),
37 });
38 }
39 col_def.into()
40 })
41 .map(|mut col: Column| {
42 col.unique = table_create
43 .get_indexes()
44 .iter()
45 .filter(|index| index.is_unique_key())
46 .map(|index| index.get_index_spec().get_column_names())
47 .filter(|col_names| col_names.len() == 1 && col_names[0] == col.name)
48 .count()
49 > 0;
50 col
51 })
52 .inspect(|col| {
53 if let sea_query::ColumnType::Enum { name, variants } = col.get_inner_col_type()
54 {
55 enums.insert(
56 name.to_string(),
57 ActiveEnum {
58 enum_name: name.clone(),
59 values: variants.clone(),
60 },
61 );
62 }
63 })
64 .collect();
65 let mut ref_table_counts: BTreeMap<String, usize> = BTreeMap::new();
66 let relations: Vec<Relation> = table_create
67 .get_foreign_key_create_stmts()
68 .iter()
69 .map(|fk_create_stmt| fk_create_stmt.get_foreign_key())
70 .map(|tbl_fk| {
71 let ref_tbl = tbl_fk.get_ref_table().unwrap().sea_orm_table().to_string();
72 if let Some(count) = ref_table_counts.get_mut(&ref_tbl) {
73 if *count == 0 {
74 *count = 1;
75 }
76 *count += 1;
77 } else {
78 ref_table_counts.insert(ref_tbl, 0);
79 };
80 tbl_fk.into()
81 })
82 .collect::<Vec<_>>()
83 .into_iter()
84 .rev()
85 .map(|mut rel: Relation| {
86 rel.self_referencing = rel.ref_table == table_name;
87 if let Some(count) = ref_table_counts.get_mut(&rel.ref_table) {
88 rel.num_suffix = *count;
89 if *count > 0 {
90 *count -= 1;
91 }
92 }
93 rel
94 })
95 .rev()
96 .collect();
97 primary_keys.extend(
98 table_create
99 .get_indexes()
100 .iter()
101 .filter(|index| index.is_primary_key())
102 .flat_map(|index| {
103 index
104 .get_index_spec()
105 .get_column_names()
106 .into_iter()
107 .map(|name| PrimaryKey { name })
108 .collect::<Vec<_>>()
109 }),
110 );
111 let entity = Entity {
112 table_name: table_name.clone(),
113 columns,
114 relations: relations.clone(),
115 conjunct_relations: vec![],
116 primary_keys,
117 };
118 entities.insert(table_name.clone(), entity.clone());
119 for mut rel in relations.into_iter() {
120 if rel.self_referencing {
122 continue;
123 }
124 if rel.num_suffix > 0 {
127 continue;
128 }
129 let ref_table = rel.ref_table;
130 let mut unique = true;
131 for column in rel.columns.iter() {
132 if !entity
133 .columns
134 .iter()
135 .filter(|col| col.unique)
136 .any(|col| col.name.as_str() == column)
137 {
138 unique = false;
139 break;
140 }
141 }
142 if rel.columns.len() == entity.primary_keys.len() {
143 let mut count_pk = 0;
144 for primary_key in entity.primary_keys.iter() {
145 if rel.columns.contains(&primary_key.name) {
146 count_pk += 1;
147 }
148 }
149 if count_pk == entity.primary_keys.len() {
150 unique = true;
151 }
152 }
153 let rel_type = if unique {
154 RelationType::HasOne
155 } else {
156 RelationType::HasMany
157 };
158 rel.rel_type = rel_type;
159 rel.ref_table = table_name.to_string();
160 rel.columns = Vec::new();
161 rel.ref_columns = Vec::new();
162 if let Some(vec) = inverse_relations.get_mut(&ref_table) {
163 vec.push(rel);
164 } else {
165 inverse_relations.insert(ref_table, vec![rel]);
166 }
167 }
168 }
169 for (tbl_name, relations) in inverse_relations.into_iter() {
170 if let Some(entity) = entities.get_mut(&tbl_name) {
171 for relation in relations.into_iter() {
172 let duplicate_relation = entity
173 .relations
174 .iter()
175 .any(|rel| rel.ref_table == relation.ref_table);
176 if !duplicate_relation {
177 entity.relations.push(relation);
178 }
179 }
180 }
181 }
182 for table_name in entities.clone().keys() {
183 let relations = match entities.get(table_name) {
184 Some(entity) => {
185 let is_conjunct_relation =
186 entity.relations.len() == 2 && entity.primary_keys.len() == 2;
187 if !is_conjunct_relation {
188 continue;
189 }
190 entity.relations.clone()
191 }
192 None => unreachable!(),
193 };
194 for (i, rel) in relations.iter().enumerate() {
195 let another_rel = relations.get((i == 0) as usize).unwrap();
196 if let Some(entity) = entities.get_mut(&rel.ref_table) {
197 let conjunct_relation = ConjunctRelation {
198 via: table_name.clone(),
199 to: another_rel.ref_table.clone(),
200 };
201 entity.conjunct_relations.push(conjunct_relation);
202 }
203 }
204 }
205 Ok(EntityWriter {
206 entities: entities
207 .into_values()
208 .map(|mut v| {
209 let duplicated_to: Vec<_> = v
211 .conjunct_relations
212 .iter()
213 .fold(HashMap::new(), |mut acc, conjunct_relation| {
214 acc.entry(conjunct_relation.to.clone())
215 .and_modify(|c| *c += 1)
216 .or_insert(1);
217 acc
218 })
219 .into_iter()
220 .filter(|(_, v)| v > &1)
221 .map(|(k, _)| k)
222 .collect();
223 v.conjunct_relations
224 .retain(|conjunct_relation| !duplicated_to.contains(&conjunct_relation.to));
225
226 v.relations.iter_mut().for_each(|relation| {
229 if v.conjunct_relations
230 .iter()
231 .any(|conjunct_relation| conjunct_relation.to == relation.ref_table)
232 {
233 relation.impl_related = false;
234 }
235 });
236
237 v.relations.sort_by(|a, b| a.ref_table.cmp(&b.ref_table));
239 v.conjunct_relations.sort_by(|a, b| a.to.cmp(&b.to));
240 v
241 })
242 .collect(),
243 enums,
244 })
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use proc_macro2::TokenStream;
    use sea_orm::{DbBackend, Schema};
    use std::{
        error::Error,
        io::{self, BufRead, BufReader, Read},
    };

    #[test]
    fn duplicated_many_to_many_paths() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::duplicated_many_to_many_paths::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
                schema.create_table_from_entity(users_saved_bills::Entity),
                schema.create_table_from_entity(users_votes::Entity),
            ],
            vec![
                (
                    "bills",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/bills.rs"),
                ),
                (
                    "users",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/users.rs"),
                ),
                (
                    "users_saved_bills",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/users_saved_bills.rs"),
                ),
                (
                    "users_votes",
                    include_str!("../tests_cfg/duplicated_many_to_many_paths/users_votes.rs"),
                ),
            ],
        )
    }

    #[test]
    fn many_to_many() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::many_to_many::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
                schema.create_table_from_entity(users_votes::Entity),
            ],
            vec![
                ("bills", include_str!("../tests_cfg/many_to_many/bills.rs")),
                ("users", include_str!("../tests_cfg/many_to_many/users.rs")),
                (
                    "users_votes",
                    include_str!("../tests_cfg/many_to_many/users_votes.rs"),
                ),
            ],
        )
    }

    #[test]
    fn many_to_many_multiple() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::many_to_many_multiple::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
                schema.create_table_from_entity(users_votes::Entity),
            ],
            vec![
                (
                    "bills",
                    include_str!("../tests_cfg/many_to_many_multiple/bills.rs"),
                ),
                (
                    "users",
                    include_str!("../tests_cfg/many_to_many_multiple/users.rs"),
                ),
                (
                    "users_votes",
                    include_str!("../tests_cfg/many_to_many_multiple/users_votes.rs"),
                ),
            ],
        )
    }

    #[test]
    fn self_referencing() -> Result<(), Box<dyn Error>> {
        use crate::tests_cfg::self_referencing::*;
        let schema = Schema::new(DbBackend::Postgres);

        validate_compact_entities(
            vec![
                schema.create_table_from_entity(bills::Entity),
                schema.create_table_from_entity(users::Entity),
            ],
            vec![
                (
                    "bills",
                    include_str!("../tests_cfg/self_referencing/bills.rs"),
                ),
                (
                    "users",
                    include_str!("../tests_cfg/self_referencing/users.rs"),
                ),
            ],
        )
    }

    /// Runs the transformer on `table_create_stmts` and compares each listed entity's
    /// generated compact code with the expected file content.
    ///
    /// `files` pairs a table name with the text of its expected entity file. The first
    /// generated code block is dropped before comparing, and `parse_from_file` skips the
    /// file text up to its first `;`. NOTE(review): presumably both are the leading `use`
    /// imports — confirm.
    fn validate_compact_entities(
        table_create_stmts: Vec<TableCreateStatement>,
        files: Vec<(&str, &str)>,
    ) -> Result<(), Box<dyn Error>> {
        let entities: HashMap<_, _> = EntityTransformer::transform(table_create_stmts)?
            .entities
            .into_iter()
            .map(|entity| (entity.table_name.clone(), entity))
            .collect();

        for (entity_name, file_content) in files {
            let entity = entities.get(entity_name).unwrap_or_else(|| {
                panic!(
                    "entity `{}` is missing from the transformer output; add its table to the list",
                    entity_name
                )
            });

            assert_eq!(
                parse_from_file(file_content.as_bytes())?.to_string(),
                EntityWriter::gen_compact_code_blocks(
                    entity,
                    &crate::WithSerde::None,
                    &crate::DateTimeCrate::Chrono,
                    &None,
                    false,
                    false,
                    &Default::default(),
                    &Default::default(),
                    false,
                    true,
                )
                .into_iter()
                .skip(1)
                .collect::<TokenStream>()
                .to_string()
            );
        }

        Ok(())
    }

    /// Parses an expected entity file into a token stream.
    ///
    /// Everything up to and including the first `;` is skipped. The remainder is
    /// tokenized, and a lexing failure is reported as an `InvalidData` I/O error
    /// instead of panicking.
    fn parse_from_file<R>(inner: R) -> io::Result<TokenStream>
    where
        R: io::Read,
    {
        let mut reader = BufReader::new(inner);
        reader.read_until(b';', &mut Vec::new())?;

        let mut content = String::new();
        reader.read_to_string(&mut content)?;
        content
            .parse::<TokenStream>()
            .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err.to_string()))
    }
}