sql_splitter/graph/format/
mermaid.rs

1//! Mermaid erDiagram format output.
2
3use crate::graph::view::GraphView;
4
5/// Generate Mermaid erDiagram from a graph view
6pub fn to_mermaid(view: &GraphView) -> String {
7    let mut output = String::new();
8
9    // Use erDiagram for proper ERD visualization
10    output.push_str("erDiagram\n");
11
12    // Generate entity definitions with attributes
13    for table in view.sorted_tables() {
14        let safe_name = escape_mermaid_id(&table.name);
15        output.push_str(&format!("    {} {{\n", safe_name));
16
17        for col in &table.columns {
18            let key_marker = if col.is_primary_key {
19                "PK"
20            } else if col.is_foreign_key {
21                "FK"
22            } else {
23                ""
24            };
25
26            let col_type = escape_mermaid_type(&col.col_type);
27            let col_name = escape_mermaid_id(&col.name);
28
29            if key_marker.is_empty() {
30                output.push_str(&format!("        {} {}\n", col_type, col_name));
31            } else {
32                output.push_str(&format!(
33                    "        {} {} {}\n",
34                    col_type, col_name, key_marker
35                ));
36            }
37        }
38
39        output.push_str("    }\n");
40    }
41
42    if !view.edges.is_empty() {
43        output.push('\n');
44    }
45
46    // Generate relationships
47    for edge in &view.edges {
48        let from = escape_mermaid_id(&edge.from_table);
49        let to = escape_mermaid_id(&edge.to_table);
50        let cardinality = edge.cardinality.as_mermaid();
51        let label = edge.from_column.clone();
52
53        output.push_str(&format!(
54            "    {} {} {} : \"{}\"\n",
55            from, cardinality, to, label
56        ));
57    }
58
59    output
60}
61
/// Turn an arbitrary name into a string usable as a Mermaid identifier.
///
/// Alphanumeric characters and `_` are kept as they are; every other
/// character is replaced by a single `_`. The result therefore has the same
/// number of characters as the input.
fn escape_mermaid_id(s: &str) -> String {
    // A replacement `_` is never longer than the character it replaces, so
    // the input's byte length is a sufficient capacity.
    let mut escaped = String::with_capacity(s.len());

    for ch in s.chars() {
        let allowed = ch == '_' || ch.is_alphanumeric();
        escaped.push(if allowed { ch } else { '_' });
    }

    escaped
}
69
/// Escape a SQL column type so it can be used as a Mermaid attribute type.
///
/// Everything from the first `(` onwards (arguments such as `(255)`) is
/// dropped, surrounding whitespace is trimmed, and every remaining character
/// that is not alphanumeric or `_` is replaced with `_`.
///
/// If nothing is left after stripping (an empty type, or a string consisting
/// only of arguments), the placeholder `unknown` is returned so that the
/// attribute line still carries a type token in front of the column name.
fn escape_mermaid_type(s: &str) -> String {
    // Placeholder used when the declared type is empty after stripping.
    const FALLBACK_TYPE: &str = "unknown";

    // `split` always yields at least one item, so the `unwrap_or` fallback is
    // never taken. Trimming stops whitespace in inputs such as
    // `VARCHAR (255)` from becoming a stray `_`.
    let base = s.split('(').next().unwrap_or(s).trim();
    if base.is_empty() {
        return FALLBACK_TYPE.to_string();
    }

    base.chars()
        .map(|c| if c.is_alphanumeric() || c == '_' { c } else { '_' })
        .collect()
}
88
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::view::{Cardinality, ColumnInfo, EdgeInfo, TableInfo};
    use ahash::AHashMap;

    /// Baseline column: non-nullable, no key flags, no reference target.
    /// Fixtures override individual fields with struct-update syntax.
    fn column(name: &str, col_type: &str) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            col_type: col_type.to_string(),
            is_primary_key: false,
            is_foreign_key: false,
            is_nullable: false,
            references_table: None,
            references_column: None,
        }
    }

    /// Table fixture called `name` that owns the given columns.
    fn table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            name: name.to_string(),
            columns,
        }
    }

    /// Two-table fixture: `users(id PK, email)` and `orders(id PK, user_id FK)`,
    /// joined by a many-to-one edge from `orders.user_id` to `users.id`.
    fn create_test_view() -> GraphView {
        let users = table(
            "users",
            vec![
                ColumnInfo {
                    is_primary_key: true,
                    ..column("id", "INT")
                },
                ColumnInfo {
                    is_nullable: true,
                    ..column("email", "VARCHAR(255)")
                },
            ],
        );

        let orders = table(
            "orders",
            vec![
                ColumnInfo {
                    is_primary_key: true,
                    ..column("id", "INT")
                },
                ColumnInfo {
                    is_foreign_key: true,
                    references_table: Some("users".to_string()),
                    references_column: Some("id".to_string()),
                    ..column("user_id", "INT")
                },
            ],
        );

        // Tables are keyed by their own name.
        let mut tables = AHashMap::new();
        for info in vec![users, orders] {
            tables.insert(info.name.clone(), info);
        }

        let orders_to_users = EdgeInfo {
            from_table: "orders".to_string(),
            from_column: "user_id".to_string(),
            to_table: "users".to_string(),
            to_column: "id".to_string(),
            cardinality: Cardinality::ManyToOne,
        };

        GraphView {
            tables,
            edges: vec![orders_to_users],
        }
    }

    /// The diagram header and both entity blocks are present.
    #[test]
    fn test_mermaid_er_diagram() {
        let diagram = to_mermaid(&create_test_view());

        assert!(diagram.contains("erDiagram"));
        assert!(diagram.contains("users {"));
        assert!(diagram.contains("orders {"));
    }

    /// Columns carry their key markers, and type arguments are stripped.
    #[test]
    fn test_mermaid_columns() {
        let diagram = to_mermaid(&create_test_view());

        assert!(diagram.contains("INT id PK"));
        assert!(diagram.contains("INT user_id FK"));
        assert!(diagram.contains("VARCHAR email"));
    }

    /// The edge is rendered with many-to-one notation and the column label.
    #[test]
    fn test_mermaid_relationships() {
        let diagram = to_mermaid(&create_test_view());

        assert!(diagram.contains("}o--||"));
        assert!(diagram.contains(": \"user_id\""));
    }
}