1use petgraph::graph::{DiGraph, NodeIndex};
11use petgraph::visit::EdgeRef;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14
/// A vertex in the [`SchemaGraph`]: a table, one of its columns, or one of its indexes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SchemaNode {
    /// A table. `estimated_rows` is an optional row-count estimate; when present
    /// the Mermaid and Graphviz exporters show it in abbreviated form.
    Table {
        name: String,
        estimated_rows: Option<u64>,
    },
    /// A column belonging to `table`.
    Column {
        table: String,
        name: String,
        // Raw SQL type text (e.g. "varchar(255)"); the exporters map it to a
        // short type word. Presumably PostgreSQL spelling — TODO confirm with callers.
        data_type: String,
        nullable: bool,
    },
    /// An index named `name` on `table`.
    Index {
        name: String,
        table: String,
        // Stored for consumers of the graph; not read by the exporters.
        unique: bool,
    },
}
37
38impl SchemaNode {
39 pub fn label(&self) -> String {
40 match self {
41 SchemaNode::Table { name, .. } => name.clone(),
42 SchemaNode::Column { table, name, .. } => format!("{}.{}", table, name),
43 SchemaNode::Index { name, table, .. } => format!("idx:{}@{}", name, table),
44 }
45 }
46}
47
/// A directed relationship in the [`SchemaGraph`]. Every edge the graph's
/// builder methods create starts at a table node.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SchemaEdge {
    /// Table → one of its columns (added by `SchemaGraph::add_column`).
    Contains,
    /// Referencing table → referenced table (added by `SchemaGraph::add_foreign_key`).
    ForeignKey {
        // Constraint name if known; the exporters fall back to another label when `None`.
        constraint_name: Option<String>,
        // Column names on the referencing / referenced side. Presumably paired
        // positionally — TODO confirm; nothing here checks the lengths match.
        from_columns: Vec<String>,
        to_columns: Vec<String>,
        // The Graphviz export draws cascading-delete FKs as dashed edges.
        cascade_delete: bool,
        // Stored for consumers of the graph; not read by the exporters.
        cascade_update: bool,
    },
    /// Table → one of its indexes (added by `SchemaGraph::add_index`).
    HasIndex,
}
67
68pub struct SchemaGraph {
73 pub graph: DiGraph<SchemaNode, SchemaEdge>,
74 pub table_index: HashMap<String, NodeIndex>,
76 pub column_index: HashMap<String, NodeIndex>,
78 pub index_index: HashMap<String, NodeIndex>,
80}
81
82impl SchemaGraph {
83 pub fn new() -> Self {
84 Self {
85 graph: DiGraph::new(),
86 table_index: HashMap::new(),
87 column_index: HashMap::new(),
88 index_index: HashMap::new(),
89 }
90 }
91
92 pub fn add_table(&mut self, name: &str, estimated_rows: Option<u64>) -> NodeIndex {
95 if let Some(&idx) = self.table_index.get(name) {
96 return idx;
97 }
98 let idx = self.graph.add_node(SchemaNode::Table {
99 name: name.to_string(),
100 estimated_rows,
101 });
102 self.table_index.insert(name.to_string(), idx);
103 idx
104 }
105
106 pub fn add_column(
107 &mut self,
108 table: &str,
109 name: &str,
110 data_type: &str,
111 nullable: bool,
112 ) -> NodeIndex {
113 let key = format!("{}.{}", table, name);
114 if let Some(&idx) = self.column_index.get(&key) {
115 return idx;
116 }
117 let idx = self.graph.add_node(SchemaNode::Column {
118 table: table.to_string(),
119 name: name.to_string(),
120 data_type: data_type.to_string(),
121 nullable,
122 });
123 self.column_index.insert(key, idx);
124
125 if let Some(&tidx) = self.table_index.get(table) {
127 self.graph.add_edge(tidx, idx, SchemaEdge::Contains);
128 }
129 idx
130 }
131
132 pub fn add_index(&mut self, index_name: &str, table: &str, unique: bool) -> NodeIndex {
133 if let Some(&idx) = self.index_index.get(index_name) {
134 return idx;
135 }
136 let idx = self.graph.add_node(SchemaNode::Index {
137 name: index_name.to_string(),
138 table: table.to_string(),
139 unique,
140 });
141 self.index_index.insert(index_name.to_string(), idx);
142
143 if let Some(&tidx) = self.table_index.get(table) {
145 self.graph.add_edge(tidx, idx, SchemaEdge::HasIndex);
146 }
147 idx
148 }
149
150 #[allow(clippy::too_many_arguments)]
151 pub fn add_foreign_key(
152 &mut self,
153 from_table: &str,
154 to_table: &str,
155 constraint_name: Option<String>,
156 from_columns: Vec<String>,
157 to_columns: Vec<String>,
158 cascade_delete: bool,
159 cascade_update: bool,
160 ) {
161 let from_idx = self.add_table(from_table, None);
163 let to_idx = self.add_table(to_table, None);
164 self.graph.add_edge(
165 from_idx,
166 to_idx,
167 SchemaEdge::ForeignKey {
168 constraint_name,
169 from_columns,
170 to_columns,
171 cascade_delete,
172 cascade_update,
173 },
174 );
175 }
176
177 pub fn tables_referencing(&self, table: &str) -> Vec<String> {
181 let Some(&tidx) = self.table_index.get(table) else {
182 return Vec::new();
183 };
184
185 use petgraph::Direction;
186 self.graph
187 .edges_directed(tidx, Direction::Incoming)
188 .filter_map(|e| {
189 if matches!(e.weight(), SchemaEdge::ForeignKey { .. }) {
190 if let SchemaNode::Table { name, .. } = &self.graph[e.source()] {
191 return Some(name.clone());
192 }
193 }
194 None
195 })
196 .collect()
197 }
198
199 pub fn fk_downstream(&self, table: &str) -> Vec<String> {
201 let Some(&tidx) = self.table_index.get(table) else {
202 return Vec::new();
203 };
204
205 use petgraph::visit::Dfs;
206 let mut dfs = Dfs::new(&self.graph, tidx);
207 let mut result = Vec::new();
208 while let Some(nx) = dfs.next(&self.graph) {
209 if nx == tidx {
210 continue;
211 }
212 if let SchemaNode::Table { name, .. } = &self.graph[nx] {
213 result.push(name.clone());
214 }
215 }
216 result
217 }
218
219 pub fn all_tables(&self) -> Vec<String> {
221 self.table_index.keys().cloned().collect()
222 }
223
224 pub fn text_summary(&self) -> String {
226 let mut lines = Vec::new();
227 for (name, &idx) in &self.table_index {
228 let refs: Vec<String> = self
229 .graph
230 .edges(idx)
231 .filter_map(|e| {
232 if let SchemaEdge::ForeignKey {
233 constraint_name, ..
234 } = e.weight()
235 {
236 if let SchemaNode::Table { name: tname, .. } = &self.graph[e.target()] {
237 let cn = constraint_name.as_deref().unwrap_or("unnamed");
238 return Some(format!(" FK({}) → {}", cn, tname));
239 }
240 }
241 None
242 })
243 .collect();
244
245 if refs.is_empty() {
246 lines.push(format!("[Table] {}", name));
247 } else {
248 lines.push(format!("[Table] {}", name));
249 lines.extend(refs);
250 }
251 }
252 lines.sort();
253 lines.join("\n")
254 }
255
256 pub fn export_mermaid(&self) -> String {
273 let mut out = String::from("erDiagram\n");
274
275 for (table_name, &table_idx) in &self.table_index {
277 out.push_str(&format!(" {} {{\n", sanitise_id(table_name)));
278
279 let mut col_lines: Vec<String> = self
281 .column_index
282 .iter()
283 .filter(|(key, _)| key.starts_with(&format!("{}.", table_name)))
284 .filter_map(|(_, &col_idx)| {
285 if let SchemaNode::Column {
286 name, data_type, ..
287 } = &self.graph[col_idx]
288 {
289 let is_pk = name == "id";
292 let pk_marker = if is_pk { " PK" } else { "" };
293 Some(format!(
294 " {} {}{}",
295 mermaid_type(data_type),
296 sanitise_id(name),
297 pk_marker
298 ))
299 } else {
300 None
301 }
302 })
303 .collect();
304 col_lines.sort();
305 for line in col_lines {
306 out.push_str(&line);
307 out.push('\n');
308 }
309
310 if let SchemaNode::Table {
312 estimated_rows: Some(rows),
313 ..
314 } = &self.graph[table_idx]
315 {
316 out.push_str(&format!(
317 " string __rows \"~{}\"\n",
318 human_rows(*rows)
319 ));
320 }
321
322 out.push_str(" }\n");
323 }
324
325 for &table_idx in self.table_index.values() {
327 for edge in self.graph.edges(table_idx) {
328 if let SchemaEdge::ForeignKey {
329 constraint_name,
330 from_columns,
331 ..
332 } = edge.weight()
333 {
334 let source = if let SchemaNode::Table { name, .. } = &self.graph[edge.source()]
336 {
337 name.clone()
338 } else {
339 continue;
340 };
341 let target = if let SchemaNode::Table { name, .. } = &self.graph[edge.target()]
342 {
343 name.clone()
344 } else {
345 continue;
346 };
347
348 let label = constraint_name.as_deref().unwrap_or_else(|| {
349 from_columns.first().map(|s| s.as_str()).unwrap_or("fk")
350 });
351
352 out.push_str(&format!(
353 " {} }}o--|| {} : \"{}\"\n",
354 sanitise_id(&source),
355 sanitise_id(&target),
356 label
357 ));
358 }
359 }
360 }
361
362 out
363 }
364
365 pub fn export_graphviz(&self) -> String {
369 let mut out = String::from(
370 "digraph schema {\n \
371 rankdir=LR;\n \
372 node [shape=record, fontsize=11, fontname=\"Helvetica\"];\n \
373 edge [fontsize=9];\n\n",
374 );
375
376 for (table_name, &table_idx) in &self.table_index {
378 let row_info = if let SchemaNode::Table {
379 estimated_rows: Some(rows),
380 ..
381 } = &self.graph[table_idx]
382 {
383 format!(" (~{})", human_rows(*rows))
384 } else {
385 String::new()
386 };
387
388 let col_labels: Vec<String> = self
389 .column_index
390 .iter()
391 .filter(|(key, _)| key.starts_with(&format!("{}.", table_name)))
392 .filter_map(|(_, &col_idx)| {
393 if let SchemaNode::Column {
394 name,
395 data_type,
396 nullable,
397 ..
398 } = &self.graph[col_idx]
399 {
400 let null_marker = if *nullable { "?" } else { "" };
401 Some(format!(
402 "{{{}{}|{}}}",
403 dot_escape(name),
404 null_marker,
405 mermaid_type(data_type)
406 ))
407 } else {
408 None
409 }
410 })
411 .collect();
412
413 let columns_str = if col_labels.is_empty() {
414 String::new()
415 } else {
416 format!("|{}", col_labels.join("|"))
417 };
418
419 out.push_str(&format!(
420 " {} [label=\"{{{}{}{}}}\" fillcolor=\"#dae8fc\" style=filled];\n",
421 sanitise_id(table_name),
422 dot_escape(table_name),
423 row_info,
424 columns_str,
425 ));
426 }
427
428 out.push('\n');
429
430 for &table_idx in self.table_index.values() {
432 for edge in self.graph.edges(table_idx) {
433 if let SchemaEdge::ForeignKey {
434 constraint_name,
435 from_columns,
436 cascade_delete,
437 ..
438 } = edge.weight()
439 {
440 let source = if let SchemaNode::Table { name, .. } = &self.graph[edge.source()]
441 {
442 name.clone()
443 } else {
444 continue;
445 };
446 let target = if let SchemaNode::Table { name, .. } = &self.graph[edge.target()]
447 {
448 name.clone()
449 } else {
450 continue;
451 };
452
453 let label = constraint_name.as_deref().unwrap_or_else(|| {
454 from_columns.first().map(|s| s.as_str()).unwrap_or("fk")
455 });
456
457 let style = if *cascade_delete { "dashed" } else { "solid" };
458
459 out.push_str(&format!(
460 " {} -> {} [label=\"{}\" style=\"{}\"];\n",
461 sanitise_id(&source),
462 sanitise_id(&target),
463 dot_escape(label),
464 style,
465 ));
466 }
467 }
468 }
469
470 out.push_str("}\n");
471 out
472 }
473}
474
/// Turns an arbitrary name into a bare identifier usable in DOT and Mermaid.
///
/// Every character other than an alphanumeric or `_` becomes `_`. If the result
/// would be empty or start with an ASCII digit, it gets a leading `_`. Both
/// languages reject such identifiers: in DOT, `2fa` lexes as the numeral `2`
/// followed by `fa`.
///
/// The mapping is not injective (`a-b` and `a.b` both become `a_b`).
fn sanitise_id(name: &str) -> String {
    let mut id: String = name
        .chars()
        .map(|c| {
            if c.is_alphanumeric() || c == '_' {
                c
            } else {
                '_'
            }
        })
        .collect();
    if id.is_empty() || id.starts_with(|c: char| c.is_ascii_digit()) {
        id.insert(0, '_');
    }
    id
}
492
/// Backslash-escapes text for a double-quoted DOT string that may be a
/// `record` label.
///
/// Escaped characters:
/// * `\` and `"`, which end or alter the quoted string;
/// * `{ } < > |`, which are structural inside record labels.
///
/// Each such character gets exactly one backslash in front of it.
fn dot_escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        if matches!(ch, '\\' | '"' | '{' | '}' | '<' | '>' | '|') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
503
/// Maps a SQL column type (PostgreSQL-style spelling) to a short type word for
/// diagrams. Unrecognised types fall back to `"string"`.
///
/// Matching is a case-insensitive substring test evaluated in order, so
/// more specific names must be tested first. `interval` and the geometric
/// `point` both contain `int` but are not integers. `serial`/`bigserial`
/// are auto-increment integers even though their names lack `int`.
fn mermaid_type(pg_type: &str) -> &'static str {
    let lower = pg_type.to_lowercase();
    let has = |needle: &str| lower.contains(needle);

    if has("interval") {
        "interval"
    } else if has("point") {
        // Geometric type; must not fall into the `int` test below.
        "string"
    } else if has("bigint") || has("int8") || has("bigserial") || has("serial8") {
        "bigint"
    } else if has("int") || has("serial") {
        "int"
    } else if has("bool") {
        "boolean"
    } else if has("text") || has("varchar") || has("char") {
        "string"
    } else if has("timestamp") || has("date") {
        "datetime"
    } else if has("uuid") {
        "uuid"
    } else if has("json") {
        "json"
    } else if has("float")
        || has("real")
        || has("double")
        || has("numeric")
        || has("decimal")
    {
        "float"
    } else if has("bytea") {
        "bytes"
    } else {
        "string"
    }
}
534
/// Abbreviates a row count for display: values below 1,000 are printed as-is;
/// larger values use one decimal and a `K`/`M`/`B` suffix.
///
/// The unit is chosen from the value as it will be displayed. Because `{:.1}`
/// rounds, 999_999 would otherwise print as `1000.0K`; it is promoted to
/// `1.0M` instead. `B` is the largest unit and is never promoted.
fn human_rows(n: u64) -> String {
    const UNITS: [(u64, &str); 3] = [(1_000, "K"), (1_000_000, "M"), (1_000_000_000, "B")];

    if n < 1_000 {
        return n.to_string();
    }

    let render = |i: usize| format!("{:.1}", n as f64 / UNITS[i].0 as f64);

    // Largest unit not exceeding `n`; always found because n >= 1_000.
    let mut unit = UNITS.iter().rposition(|&(div, _)| n >= div).unwrap_or(0);
    let mut shown = render(unit);

    if unit + 1 < UNITS.len() && matches!(shown.parse::<f64>(), Ok(v) if v >= 1_000.0) {
        unit += 1;
        shown = render(unit);
    }

    format!("{}{}", shown, UNITS[unit].1)
}
547
548impl Default for SchemaGraph {
549 fn default() -> Self {
550 Self::new()
551 }
552}