1use std::collections::{BTreeSet, HashSet};
2use std::fmt::Write;
3
4use sea_query::ColumnType;
5
6use crate::{Entity, RelationType};
7
8use super::EntityWriter;
9
10impl EntityWriter {
11 pub fn generate_er_diagram(&self) -> String {
12 let mut out = String::from("erDiagram\n");
13
14 let pk_sets: Vec<HashSet<&str>> = self
15 .entities
16 .iter()
17 .map(|e| e.primary_keys.iter().map(|pk| pk.name.as_str()).collect())
18 .collect();
19
20 let fk_sets: Vec<HashSet<&str>> = self
21 .entities
22 .iter()
23 .map(|e| {
24 e.relations
25 .iter()
26 .filter(|r| matches!(r.rel_type, RelationType::BelongsTo))
27 .flat_map(|r| r.columns.iter().map(String::as_str))
28 .collect()
29 })
30 .collect();
31
32 for (i, entity) in self.entities.iter().enumerate() {
33 write_entity_block(&mut out, entity, &pk_sets[i], &fk_sets[i]);
34 }
35
36 let mut emitted: BTreeSet<String> = BTreeSet::new();
37
38 for entity in &self.entities {
39 write_relations(&mut out, entity, &mut emitted);
40 }
41
42 out
43 }
44}
45
46fn write_entity_block(out: &mut String, entity: &Entity, pks: &HashSet<&str>, fks: &HashSet<&str>) {
47 let _ = writeln!(out, " {} {{", entity.table_name);
48
49 for col in &entity.columns {
50 let type_name = col_type_name(&col.col_type);
51 let is_pk = pks.contains(col.name.as_str());
52 let is_fk = fks.contains(col.name.as_str());
53 let is_uk = col.unique || col.unique_key.is_some();
54
55 let constraint = match (is_pk, is_fk, is_uk) {
56 (true, true, _) => " PK,FK",
57 (true, false, _) => " PK",
58 (false, true, true) => " FK,UK",
59 (false, true, false) => " FK",
60 (false, false, true) => " UK",
61 (false, false, false) => "",
62 };
63
64 let _ = writeln!(out, " {} {}{}", type_name, col.name, constraint);
65 }
66
67 let _ = writeln!(out, " }}");
68}
69
70fn write_relations(out: &mut String, entity: &Entity, emitted: &mut BTreeSet<String>) {
71 for rel in &entity.relations {
72 let (left, right, cardinality, label) = match rel.rel_type {
73 RelationType::BelongsTo => (
74 &entity.table_name,
75 &rel.ref_table,
76 "}o--||",
77 rel.columns.join(", "),
78 ),
79 RelationType::HasOne => continue,
80 RelationType::HasMany => continue,
81 };
82
83 let key = format!("{left} {cardinality} {right} : \"{label}\"");
84 if emitted.insert(key.clone()) {
85 let _ = writeln!(out, " {key}");
86 }
87 }
88
89 for conj in &entity.conjunct_relations {
90 let left = &entity.table_name;
91 let right = &conj.to;
92 let label = format!("[{}]", conj.via);
93
94 let key = if left <= right {
95 format!("{left} }}o--o{{ {right} : \"{label}\"")
96 } else {
97 format!("{right} }}o--o{{ {left} : \"{label}\"")
98 };
99
100 if emitted.insert(key.clone()) {
101 let _ = writeln!(out, " {key}");
102 }
103 }
104}
105
106fn col_type_name(col_type: &ColumnType) -> &str {
107 #[allow(unreachable_patterns)]
108 match col_type {
109 ColumnType::Char(_) => "char",
110 ColumnType::String(_) => "varchar",
111 ColumnType::Text => "text",
112 ColumnType::TinyInteger => "tinyint",
113 ColumnType::SmallInteger => "smallint",
114 ColumnType::Integer => "int",
115 ColumnType::BigInteger => "bigint",
116 ColumnType::TinyUnsigned => "tinyint_unsigned",
117 ColumnType::SmallUnsigned => "smallint_unsigned",
118 ColumnType::Unsigned => "int_unsigned",
119 ColumnType::BigUnsigned => "bigint_unsigned",
120 ColumnType::Float => "float",
121 ColumnType::Double => "double",
122 ColumnType::Decimal(_) => "decimal",
123 ColumnType::Money(_) => "money",
124 ColumnType::DateTime => "datetime",
125 ColumnType::Timestamp => "timestamp",
126 ColumnType::TimestampWithTimeZone => "timestamptz",
127 ColumnType::Time => "time",
128 ColumnType::Date => "date",
129 ColumnType::Year => "year",
130 ColumnType::Binary(_) | ColumnType::VarBinary(_) | ColumnType::Blob => "blob",
131 ColumnType::Boolean => "bool",
132 ColumnType::Json | ColumnType::JsonBinary => "json",
133 ColumnType::Uuid => "uuid",
134 ColumnType::Enum { .. } => "enum",
135 ColumnType::Array(_) => "array",
136 ColumnType::Vector(_) => "vector",
137 ColumnType::Bit(_) | ColumnType::VarBit(_) => "bit",
138 ColumnType::Cidr => "cidr",
139 ColumnType::Inet => "inet",
140 ColumnType::MacAddr => "macaddr",
141 ColumnType::LTree => "ltree",
142 ColumnType::Interval(_, _) => "interval",
143 ColumnType::Custom(_) => "custom",
144 _ => "unknown",
145 }
146}
147
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use sea_query::{ColumnType, StringLen};

    use crate::{
        Column, ConjunctRelation, Entity, EntityWriter, PrimaryKey, Relation, RelationType,
    };

    /// Builds a column that is not part of any composite unique key.
    fn column(
        name: &str,
        col_type: ColumnType,
        auto_increment: bool,
        not_null: bool,
        unique: bool,
    ) -> Column {
        Column {
            name: name.to_owned(),
            col_type,
            auto_increment,
            not_null,
            unique,
            unique_key: None,
        }
    }

    /// Builds a relation with no ON DELETE / ON UPDATE action, no numeric suffix
    /// and `impl_related` enabled.
    fn relation(
        ref_table: &str,
        columns: &[&str],
        ref_columns: &[&str],
        rel_type: RelationType,
        self_referencing: bool,
    ) -> Relation {
        Relation {
            ref_table: ref_table.to_owned(),
            columns: columns.iter().map(|&c| c.to_owned()).collect(),
            ref_columns: ref_columns.iter().map(|&c| c.to_owned()).collect(),
            rel_type,
            on_delete: None,
            on_update: None,
            self_referencing,
            num_suffix: 0,
            impl_related: true,
        }
    }

    /// Builds a many-to-many link to `to` through the junction table `via`.
    fn conjunct(via: &str, to: &str) -> ConjunctRelation {
        ConjunctRelation {
            via: via.to_owned(),
            to: to.to_owned(),
        }
    }

    /// Builds the primary-key list from column names, preserving order.
    fn primary_keys(names: &[&str]) -> Vec<PrimaryKey> {
        names
            .iter()
            .map(|&name| PrimaryKey {
                name: name.to_owned(),
            })
            .collect()
    }

    /// user (self-referencing via `parent_id`) 1-N post, post N-M tag via post_tag.
    fn setup_blog_schema() -> EntityWriter {
        let user = Entity {
            table_name: "user".to_owned(),
            columns: vec![
                column("id", ColumnType::Integer, true, true, false),
                column("name", ColumnType::String(StringLen::N(255)), false, true, false),
                column("email", ColumnType::String(StringLen::N(255)), false, true, true),
                column("parent_id", ColumnType::Integer, false, false, false),
            ],
            relations: vec![
                relation("post", &[], &[], RelationType::HasMany, false),
                relation("user", &["parent_id"], &["id"], RelationType::BelongsTo, true),
            ],
            conjunct_relations: vec![],
            primary_keys: primary_keys(&["id"]),
        };

        let post = Entity {
            table_name: "post".to_owned(),
            columns: vec![
                column("id", ColumnType::Integer, true, true, false),
                column("title", ColumnType::Text, false, true, false),
                column("user_id", ColumnType::Integer, false, true, false),
            ],
            relations: vec![relation(
                "user",
                &["user_id"],
                &["id"],
                RelationType::BelongsTo,
                false,
            )],
            conjunct_relations: vec![conjunct("post_tag", "tag")],
            primary_keys: primary_keys(&["id"]),
        };

        let tag = Entity {
            table_name: "tag".to_owned(),
            columns: vec![
                column("id", ColumnType::Integer, true, true, false),
                column("name", ColumnType::String(StringLen::N(100)), false, true, true),
            ],
            relations: vec![],
            conjunct_relations: vec![conjunct("post_tag", "post")],
            primary_keys: primary_keys(&["id"]),
        };

        let post_tag = Entity {
            table_name: "post_tag".to_owned(),
            columns: vec![
                column("post_id", ColumnType::Integer, false, true, false),
                column("tag_id", ColumnType::Integer, false, true, false),
            ],
            relations: vec![
                relation("post", &["post_id"], &["id"], RelationType::BelongsTo, false),
                relation("tag", &["tag_id"], &["id"], RelationType::BelongsTo, false),
            ],
            conjunct_relations: vec![],
            primary_keys: primary_keys(&["post_id", "tag_id"]),
        };

        EntityWriter {
            entities: vec![user, post, tag, post_tag],
            enums: BTreeMap::new(),
        }
    }

    #[test]
    fn test_generate_er_diagram() {
        let diagram = setup_blog_schema().generate_er_diagram();

        let expected = r#"erDiagram
    user {
        int id PK
        varchar name
        varchar email UK
        int parent_id FK
    }
    post {
        int id PK
        text title
        int user_id FK
    }
    tag {
        int id PK
        varchar name UK
    }
    post_tag {
        int post_id PK,FK
        int tag_id PK,FK
    }
    user }o--|| user : "parent_id"
    post }o--|| user : "user_id"
    post }o--o{ tag : "[post_tag]"
    post_tag }o--|| post : "post_id"
    post_tag }o--|| tag : "tag_id"
"#;

        assert_eq!(diagram, expected);
    }

    #[test]
    fn test_er_diagram_deduplicates_m2m() {
        let diagram = setup_blog_schema().generate_er_diagram();

        let m2m_lines = diagram.matches("}o--o{").count();
        assert_eq!(m2m_lines, 1, "M-N relation should appear only once");
    }
}