// sea_orm_codegen/entity/writer/mermaid.rs

1use std::collections::{BTreeSet, HashSet};
2use std::fmt::Write;
3
4use sea_query::ColumnType;
5
6use crate::{Entity, RelationType};
7
8use super::EntityWriter;
9
10impl EntityWriter {
11    pub fn generate_er_diagram(&self) -> String {
12        let mut out = String::from("erDiagram\n");
13
14        let pk_sets: Vec<HashSet<&str>> = self
15            .entities
16            .iter()
17            .map(|e| e.primary_keys.iter().map(|pk| pk.name.as_str()).collect())
18            .collect();
19
20        let fk_sets: Vec<HashSet<&str>> = self
21            .entities
22            .iter()
23            .map(|e| {
24                e.relations
25                    .iter()
26                    .filter(|r| matches!(r.rel_type, RelationType::BelongsTo))
27                    .flat_map(|r| r.columns.iter().map(String::as_str))
28                    .collect()
29            })
30            .collect();
31
32        for (i, entity) in self.entities.iter().enumerate() {
33            write_entity_block(&mut out, entity, &pk_sets[i], &fk_sets[i]);
34        }
35
36        let mut emitted: BTreeSet<String> = BTreeSet::new();
37
38        for entity in &self.entities {
39            write_relations(&mut out, entity, &mut emitted);
40        }
41
42        out
43    }
44}
45
46fn write_entity_block(out: &mut String, entity: &Entity, pks: &HashSet<&str>, fks: &HashSet<&str>) {
47    let _ = writeln!(out, "    {} {{", entity.table_name);
48
49    for col in &entity.columns {
50        let type_name = col_type_name(&col.col_type);
51        let is_pk = pks.contains(col.name.as_str());
52        let is_fk = fks.contains(col.name.as_str());
53        let is_uk = col.unique || col.unique_key.is_some();
54
55        let constraint = match (is_pk, is_fk, is_uk) {
56            (true, true, _) => " PK,FK",
57            (true, false, _) => " PK",
58            (false, true, true) => " FK,UK",
59            (false, true, false) => " FK",
60            (false, false, true) => " UK",
61            (false, false, false) => "",
62        };
63
64        let _ = writeln!(out, "        {} {}{}", type_name, col.name, constraint);
65    }
66
67    let _ = writeln!(out, "    }}");
68}
69
70fn write_relations(out: &mut String, entity: &Entity, emitted: &mut BTreeSet<String>) {
71    for rel in &entity.relations {
72        let (left, right, cardinality, label) = match rel.rel_type {
73            RelationType::BelongsTo => (
74                &entity.table_name,
75                &rel.ref_table,
76                "}o--||",
77                rel.columns.join(", "),
78            ),
79            RelationType::HasOne => continue,
80            RelationType::HasMany => continue,
81        };
82
83        let key = format!("{left} {cardinality} {right} : \"{label}\"");
84        if emitted.insert(key.clone()) {
85            let _ = writeln!(out, "    {key}");
86        }
87    }
88
89    for conj in &entity.conjunct_relations {
90        let left = &entity.table_name;
91        let right = &conj.to;
92        let label = format!("[{}]", conj.via);
93
94        let key = if left <= right {
95            format!("{left} }}o--o{{ {right} : \"{label}\"")
96        } else {
97            format!("{right} }}o--o{{ {left} : \"{label}\"")
98        };
99
100        if emitted.insert(key.clone()) {
101            let _ = writeln!(out, "    {key}");
102        }
103    }
104}
105
106fn col_type_name(col_type: &ColumnType) -> &str {
107    #[allow(unreachable_patterns)]
108    match col_type {
109        ColumnType::Char(_) => "char",
110        ColumnType::String(_) => "varchar",
111        ColumnType::Text => "text",
112        ColumnType::TinyInteger => "tinyint",
113        ColumnType::SmallInteger => "smallint",
114        ColumnType::Integer => "int",
115        ColumnType::BigInteger => "bigint",
116        ColumnType::TinyUnsigned => "tinyint_unsigned",
117        ColumnType::SmallUnsigned => "smallint_unsigned",
118        ColumnType::Unsigned => "int_unsigned",
119        ColumnType::BigUnsigned => "bigint_unsigned",
120        ColumnType::Float => "float",
121        ColumnType::Double => "double",
122        ColumnType::Decimal(_) => "decimal",
123        ColumnType::Money(_) => "money",
124        ColumnType::DateTime => "datetime",
125        ColumnType::Timestamp => "timestamp",
126        ColumnType::TimestampWithTimeZone => "timestamptz",
127        ColumnType::Time => "time",
128        ColumnType::Date => "date",
129        ColumnType::Year => "year",
130        ColumnType::Binary(_) | ColumnType::VarBinary(_) | ColumnType::Blob => "blob",
131        ColumnType::Boolean => "bool",
132        ColumnType::Json | ColumnType::JsonBinary => "json",
133        ColumnType::Uuid => "uuid",
134        ColumnType::Enum { .. } => "enum",
135        ColumnType::Array(_) => "array",
136        ColumnType::Vector(_) => "vector",
137        ColumnType::Bit(_) | ColumnType::VarBit(_) => "bit",
138        ColumnType::Cidr => "cidr",
139        ColumnType::Inet => "inet",
140        ColumnType::MacAddr => "macaddr",
141        ColumnType::LTree => "ltree",
142        ColumnType::Interval(_, _) => "interval",
143        ColumnType::Custom(_) => "custom",
144        _ => "unknown",
145    }
146}
147
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use sea_query::{ColumnType, StringLen};

    use crate::{
        Column, ConjunctRelation, Entity, EntityWriter, PrimaryKey, Relation, RelationType,
    };

    /// A non-null, non-unique, non-auto-increment column. Individual flags are
    /// overridden at call sites with struct-update syntax.
    fn column(name: &str, col_type: ColumnType) -> Column {
        Column {
            name: name.to_owned(),
            col_type,
            auto_increment: false,
            not_null: true,
            unique: false,
            unique_key: None,
        }
    }

    /// An auto-increment, non-null integer `id` column.
    fn id_column() -> Column {
        Column {
            auto_increment: true,
            ..column("id", ColumnType::Integer)
        }
    }

    /// A non-self-referencing relation with no `on_delete` / `on_update` action.
    fn relation(
        ref_table: &str,
        rel_type: RelationType,
        columns: &[&str],
        ref_columns: &[&str],
    ) -> Relation {
        Relation {
            ref_table: ref_table.to_owned(),
            columns: columns.iter().copied().map(str::to_owned).collect(),
            ref_columns: ref_columns.iter().copied().map(str::to_owned).collect(),
            rel_type,
            on_delete: None,
            on_update: None,
            self_referencing: false,
            num_suffix: 0,
            impl_related: true,
        }
    }

    /// A many-to-many link to `to` through the junction table `via`.
    fn conjunct(via: &str, to: &str) -> ConjunctRelation {
        ConjunctRelation {
            via: via.to_owned(),
            to: to.to_owned(),
        }
    }

    fn entity(
        table_name: &str,
        columns: Vec<Column>,
        relations: Vec<Relation>,
        conjunct_relations: Vec<ConjunctRelation>,
        primary_keys: &[&str],
    ) -> Entity {
        Entity {
            table_name: table_name.to_owned(),
            columns,
            relations,
            conjunct_relations,
            primary_keys: primary_keys
                .iter()
                .map(|name| PrimaryKey {
                    name: (*name).to_owned(),
                })
                .collect(),
        }
    }

    /// `user` (self-referencing via `parent_id`) has many `post`s.
    /// `post` and `tag` are linked many-to-many through `post_tag`.
    fn setup_blog_schema() -> EntityWriter {
        let user = entity(
            "user",
            vec![
                id_column(),
                column("name", ColumnType::String(StringLen::N(255))),
                Column {
                    unique: true,
                    ..column("email", ColumnType::String(StringLen::N(255)))
                },
                Column {
                    not_null: false,
                    ..column("parent_id", ColumnType::Integer)
                },
            ],
            vec![
                relation("post", RelationType::HasMany, &[], &[]),
                Relation {
                    self_referencing: true,
                    ..relation("user", RelationType::BelongsTo, &["parent_id"], &["id"])
                },
            ],
            vec![],
            &["id"],
        );

        let post = entity(
            "post",
            vec![
                id_column(),
                column("title", ColumnType::Text),
                column("user_id", ColumnType::Integer),
            ],
            vec![relation("user", RelationType::BelongsTo, &["user_id"], &["id"])],
            vec![conjunct("post_tag", "tag")],
            &["id"],
        );

        let tag = entity(
            "tag",
            vec![
                id_column(),
                Column {
                    unique: true,
                    ..column("name", ColumnType::String(StringLen::N(100)))
                },
            ],
            vec![],
            vec![conjunct("post_tag", "post")],
            &["id"],
        );

        let post_tag = entity(
            "post_tag",
            vec![
                column("post_id", ColumnType::Integer),
                column("tag_id", ColumnType::Integer),
            ],
            vec![
                relation("post", RelationType::BelongsTo, &["post_id"], &["id"]),
                relation("tag", RelationType::BelongsTo, &["tag_id"], &["id"]),
            ],
            vec![],
            &["post_id", "tag_id"],
        );

        EntityWriter {
            entities: vec![user, post, tag, post_tag],
            enums: BTreeMap::new(),
        }
    }

    #[test]
    fn test_generate_er_diagram() {
        let diagram = setup_blog_schema().generate_er_diagram();

        let expected = r#"erDiagram
    user {
        int id PK
        varchar name
        varchar email UK
        int parent_id FK
    }
    post {
        int id PK
        text title
        int user_id FK
    }
    tag {
        int id PK
        varchar name UK
    }
    post_tag {
        int post_id PK,FK
        int tag_id PK,FK
    }
    user }o--|| user : "parent_id"
    post }o--|| user : "user_id"
    post }o--o{ tag : "[post_tag]"
    post_tag }o--|| post : "post_id"
    post_tag }o--|| tag : "tag_id"
"#;

        assert_eq!(diagram, expected);
    }

    #[test]
    fn test_er_diagram_deduplicates_m2m() {
        let diagram = setup_blog_schema().generate_er_diagram();

        assert_eq!(
            diagram.matches("}o--o{").count(),
            1,
            "M-N relation should appear only once"
        );
    }
}