Skip to main content

schema_risk/
graph.rs

//! Schema graph – represents tables, columns, indexes, and their dependencies.
//!
//! We use `petgraph::DiGraph` where:
//!   - Each node is a `SchemaNode` (Table, Column, or Index)
//!   - Each edge is a `SchemaEdge` (Contains, ForeignKey, HasIndex)
//!
//! The graph lets us answer questions like:
//!   "If I drop table X, what else breaks?"
9
10use petgraph::graph::{DiGraph, NodeIndex};
11use petgraph::visit::EdgeRef;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14
15// ─────────────────────────────────────────────
16// Node types
17// ─────────────────────────────────────────────
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub enum SchemaNode {
21    Table {
22        name: String,
23        estimated_rows: Option<u64>,
24    },
25    Column {
26        table: String,
27        name: String,
28        data_type: String,
29        nullable: bool,
30    },
31    Index {
32        name: String,
33        table: String,
34        unique: bool,
35    },
36}
37
38impl SchemaNode {
39    pub fn label(&self) -> String {
40        match self {
41            SchemaNode::Table { name, .. } => name.clone(),
42            SchemaNode::Column { table, name, .. } => format!("{}.{}", table, name),
43            SchemaNode::Index { name, table, .. } => format!("idx:{}@{}", name, table),
44        }
45    }
46}
47
48// ─────────────────────────────────────────────
49// Edge types
50// ─────────────────────────────────────────────
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub enum SchemaEdge {
54    /// Table → Column  (table contains this column)
55    Contains,
56    /// Table → Table   (foreign key relationship)
57    ForeignKey {
58        constraint_name: Option<String>,
59        from_columns: Vec<String>,
60        to_columns: Vec<String>,
61        cascade_delete: bool,
62        cascade_update: bool,
63    },
64    /// Table → Index
65    HasIndex,
66}
67
68// ─────────────────────────────────────────────
69// The graph structure
70// ─────────────────────────────────────────────
71
72pub struct SchemaGraph {
73    pub graph: DiGraph<SchemaNode, SchemaEdge>,
74    /// table_name -> NodeIndex
75    pub table_index: HashMap<String, NodeIndex>,
76    /// "table.column" -> NodeIndex
77    pub column_index: HashMap<String, NodeIndex>,
78    /// index_name -> NodeIndex
79    pub index_index: HashMap<String, NodeIndex>,
80}
81
82impl SchemaGraph {
83    pub fn new() -> Self {
84        Self {
85            graph: DiGraph::new(),
86            table_index: HashMap::new(),
87            column_index: HashMap::new(),
88            index_index: HashMap::new(),
89        }
90    }
91
92    // ── Insertion helpers ─────────────────────────────────────────────────
93
94    pub fn add_table(&mut self, name: &str, estimated_rows: Option<u64>) -> NodeIndex {
95        if let Some(&idx) = self.table_index.get(name) {
96            return idx;
97        }
98        let idx = self.graph.add_node(SchemaNode::Table {
99            name: name.to_string(),
100            estimated_rows,
101        });
102        self.table_index.insert(name.to_string(), idx);
103        idx
104    }
105
106    pub fn add_column(
107        &mut self,
108        table: &str,
109        name: &str,
110        data_type: &str,
111        nullable: bool,
112    ) -> NodeIndex {
113        let key = format!("{}.{}", table, name);
114        if let Some(&idx) = self.column_index.get(&key) {
115            return idx;
116        }
117        let idx = self.graph.add_node(SchemaNode::Column {
118            table: table.to_string(),
119            name: name.to_string(),
120            data_type: data_type.to_string(),
121            nullable,
122        });
123        self.column_index.insert(key, idx);
124
125        // Connect table → column
126        if let Some(&tidx) = self.table_index.get(table) {
127            self.graph.add_edge(tidx, idx, SchemaEdge::Contains);
128        }
129        idx
130    }
131
132    pub fn add_index(&mut self, index_name: &str, table: &str, unique: bool) -> NodeIndex {
133        if let Some(&idx) = self.index_index.get(index_name) {
134            return idx;
135        }
136        let idx = self.graph.add_node(SchemaNode::Index {
137            name: index_name.to_string(),
138            table: table.to_string(),
139            unique,
140        });
141        self.index_index.insert(index_name.to_string(), idx);
142
143        // Connect table → index
144        if let Some(&tidx) = self.table_index.get(table) {
145            self.graph.add_edge(tidx, idx, SchemaEdge::HasIndex);
146        }
147        idx
148    }
149
150    #[allow(clippy::too_many_arguments)]
151    pub fn add_foreign_key(
152        &mut self,
153        from_table: &str,
154        to_table: &str,
155        constraint_name: Option<String>,
156        from_columns: Vec<String>,
157        to_columns: Vec<String>,
158        cascade_delete: bool,
159        cascade_update: bool,
160    ) {
161        // Make sure both tables exist as nodes (create ghost nodes if not yet seen)
162        let from_idx = self.add_table(from_table, None);
163        let to_idx = self.add_table(to_table, None);
164        self.graph.add_edge(
165            from_idx,
166            to_idx,
167            SchemaEdge::ForeignKey {
168                constraint_name,
169                from_columns,
170                to_columns,
171                cascade_delete,
172                cascade_update,
173            },
174        );
175    }
176
177    // ── Query helpers ─────────────────────────────────────────────────────
178
179    /// Returns all tables that hold a foreign key pointing TO the given table.
180    pub fn tables_referencing(&self, table: &str) -> Vec<String> {
181        let Some(&tidx) = self.table_index.get(table) else {
182            return Vec::new();
183        };
184
185        use petgraph::Direction;
186        self.graph
187            .edges_directed(tidx, Direction::Incoming)
188            .filter_map(|e| {
189                if matches!(e.weight(), SchemaEdge::ForeignKey { .. }) {
190                    if let SchemaNode::Table { name, .. } = &self.graph[e.source()] {
191                        return Some(name.clone());
192                    }
193                }
194                None
195            })
196            .collect()
197    }
198
199    /// Depth-first search: all tables reachable from `table` via FK edges.
200    pub fn fk_downstream(&self, table: &str) -> Vec<String> {
201        let Some(&tidx) = self.table_index.get(table) else {
202            return Vec::new();
203        };
204
205        use petgraph::visit::Dfs;
206        let mut dfs = Dfs::new(&self.graph, tidx);
207        let mut result = Vec::new();
208        while let Some(nx) = dfs.next(&self.graph) {
209            if nx == tidx {
210                continue;
211            }
212            if let SchemaNode::Table { name, .. } = &self.graph[nx] {
213                result.push(name.clone());
214            }
215        }
216        result
217    }
218
219    /// List all tables in the graph.
220    pub fn all_tables(&self) -> Vec<String> {
221        self.table_index.keys().cloned().collect()
222    }
223
224    /// Produces a plain-text adjacency summary of the graph.
225    pub fn text_summary(&self) -> String {
226        let mut lines = Vec::new();
227        for (name, &idx) in &self.table_index {
228            let refs: Vec<String> = self
229                .graph
230                .edges(idx)
231                .filter_map(|e| {
232                    if let SchemaEdge::ForeignKey {
233                        constraint_name, ..
234                    } = e.weight()
235                    {
236                        if let SchemaNode::Table { name: tname, .. } = &self.graph[e.target()] {
237                            let cn = constraint_name.as_deref().unwrap_or("unnamed");
238                            return Some(format!("  FK({}) → {}", cn, tname));
239                        }
240                    }
241                    None
242                })
243                .collect();
244
245            if refs.is_empty() {
246                lines.push(format!("[Table] {}", name));
247            } else {
248                lines.push(format!("[Table] {}", name));
249                lines.extend(refs);
250            }
251        }
252        lines.sort();
253        lines.join("\n")
254    }
255
256    // ── Graph export ──────────────────────────────────────────────────────
257
258    /// Export the schema as a Mermaid ER diagram.
259    ///
260    /// Output can be embedded in a Markdown file and rendered by GitHub,
261    /// GitLab, Notion, etc.
262    ///
263    /// Example:
264    /// ```mermaid
265    /// erDiagram
266    ///     users {
267    ///         uuid id PK
268    ///         text email
269    ///     }
270    ///     orders ||--o{ users : "user_id"
271    /// ```
272    pub fn export_mermaid(&self) -> String {
273        let mut out = String::from("erDiagram\n");
274
275        // ── Table definitions ─────────────────────────────────────────────
276        for (table_name, &table_idx) in &self.table_index {
277            out.push_str(&format!("    {} {{\n", sanitise_id(table_name)));
278
279            // Columns that belong to this table
280            let mut col_lines: Vec<String> = self
281                .column_index
282                .iter()
283                .filter(|(key, _)| key.starts_with(&format!("{}.", table_name)))
284                .filter_map(|(_, &col_idx)| {
285                    if let SchemaNode::Column {
286                        name, data_type, ..
287                    } = &self.graph[col_idx]
288                    {
289                        // Detect primary key by checking if the column name is "id" or
290                        // if the table has an index that covers only this column.
291                        let is_pk = name == "id";
292                        let pk_marker = if is_pk { " PK" } else { "" };
293                        Some(format!(
294                            "        {} {}{}",
295                            mermaid_type(data_type),
296                            sanitise_id(name),
297                            pk_marker
298                        ))
299                    } else {
300                        None
301                    }
302                })
303                .collect();
304            col_lines.sort();
305            for line in col_lines {
306                out.push_str(&line);
307                out.push('\n');
308            }
309
310            // Row estimate comment
311            if let SchemaNode::Table {
312                estimated_rows: Some(rows),
313                ..
314            } = &self.graph[table_idx]
315            {
316                out.push_str(&format!(
317                    "        string __rows \"~{}\"\n",
318                    human_rows(*rows)
319                ));
320            }
321
322            out.push_str("    }\n");
323        }
324
325        // ── Relationships ─────────────────────────────────────────────────
326        for &table_idx in self.table_index.values() {
327            for edge in self.graph.edges(table_idx) {
328                if let SchemaEdge::ForeignKey {
329                    constraint_name,
330                    from_columns,
331                    ..
332                } = edge.weight()
333                {
334                    // Determine source and target table names
335                    let source = if let SchemaNode::Table { name, .. } = &self.graph[edge.source()]
336                    {
337                        name.clone()
338                    } else {
339                        continue;
340                    };
341                    let target = if let SchemaNode::Table { name, .. } = &self.graph[edge.target()]
342                    {
343                        name.clone()
344                    } else {
345                        continue;
346                    };
347
348                    let label = constraint_name.as_deref().unwrap_or_else(|| {
349                        from_columns.first().map(|s| s.as_str()).unwrap_or("fk")
350                    });
351
352                    out.push_str(&format!(
353                        "    {} }}o--|| {} : \"{}\"\n",
354                        sanitise_id(&source),
355                        sanitise_id(&target),
356                        label
357                    ));
358                }
359            }
360        }
361
362        out
363    }
364
365    /// Export the schema as a Graphviz DOT document.
366    ///
367    /// Pipe to `dot -Tsvg -o schema.svg` or `dot -Tpng -o schema.png`.
368    pub fn export_graphviz(&self) -> String {
369        let mut out = String::from(
370            "digraph schema {\n  \
371             rankdir=LR;\n  \
372             node [shape=record, fontsize=11, fontname=\"Helvetica\"];\n  \
373             edge [fontsize=9];\n\n",
374        );
375
376        // ── Table nodes (record shape with column list) ───────────────────
377        for (table_name, &table_idx) in &self.table_index {
378            let row_info = if let SchemaNode::Table {
379                estimated_rows: Some(rows),
380                ..
381            } = &self.graph[table_idx]
382            {
383                format!(" (~{})", human_rows(*rows))
384            } else {
385                String::new()
386            };
387
388            let col_labels: Vec<String> = self
389                .column_index
390                .iter()
391                .filter(|(key, _)| key.starts_with(&format!("{}.", table_name)))
392                .filter_map(|(_, &col_idx)| {
393                    if let SchemaNode::Column {
394                        name,
395                        data_type,
396                        nullable,
397                        ..
398                    } = &self.graph[col_idx]
399                    {
400                        let null_marker = if *nullable { "?" } else { "" };
401                        Some(format!(
402                            "{{{}{}|{}}}",
403                            dot_escape(name),
404                            null_marker,
405                            mermaid_type(data_type)
406                        ))
407                    } else {
408                        None
409                    }
410                })
411                .collect();
412
413            let columns_str = if col_labels.is_empty() {
414                String::new()
415            } else {
416                format!("|{}", col_labels.join("|"))
417            };
418
419            out.push_str(&format!(
420                "  {} [label=\"{{{}{}{}}}\" fillcolor=\"#dae8fc\" style=filled];\n",
421                sanitise_id(table_name),
422                dot_escape(table_name),
423                row_info,
424                columns_str,
425            ));
426        }
427
428        out.push('\n');
429
430        // ── FK edges ─────────────────────────────────────────────────────
431        for &table_idx in self.table_index.values() {
432            for edge in self.graph.edges(table_idx) {
433                if let SchemaEdge::ForeignKey {
434                    constraint_name,
435                    from_columns,
436                    cascade_delete,
437                    ..
438                } = edge.weight()
439                {
440                    let source = if let SchemaNode::Table { name, .. } = &self.graph[edge.source()]
441                    {
442                        name.clone()
443                    } else {
444                        continue;
445                    };
446                    let target = if let SchemaNode::Table { name, .. } = &self.graph[edge.target()]
447                    {
448                        name.clone()
449                    } else {
450                        continue;
451                    };
452
453                    let label = constraint_name.as_deref().unwrap_or_else(|| {
454                        from_columns.first().map(|s| s.as_str()).unwrap_or("fk")
455                    });
456
457                    let style = if *cascade_delete { "dashed" } else { "solid" };
458
459                    out.push_str(&format!(
460                        "  {} -> {} [label=\"{}\" style=\"{}\"];\n",
461                        sanitise_id(&source),
462                        sanitise_id(&target),
463                        dot_escape(label),
464                        style,
465                    ));
466                }
467            }
468        }
469
470        out.push_str("}\n");
471        out
472    }
473}
474
475// ─────────────────────────────────────────────
476// Helpers
477// ─────────────────────────────────────────────
478
/// Turn a table/column name into an identifier usable in Mermaid and DOT output.
///
/// Every character that is neither alphanumeric nor `_` (spaces, hyphens, dots,
/// other punctuation) is replaced with `_`. All other characters are kept as-is.
fn sanitise_id(name: &str) -> String {
    let mut id = String::with_capacity(name.len());
    for ch in name.chars() {
        let keep = ch == '_' || ch.is_alphanumeric();
        id.push(if keep { ch } else { '_' });
    }
    id
}
492
/// Escape a string for embedding inside DOT label double-quotes.
///
/// Each of `\`, `"`, `{`, `}`, `<`, `>` and `|` is prefixed with a backslash; the
/// last five are structural in record-shaped labels.
fn dot_escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        if matches!(ch, '\\' | '"' | '{' | '}' | '<' | '>' | '|') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
503
/// Collapse verbose PostgreSQL type names to a short Mermaid-friendly form.
///
/// Matching is case-insensitive substring matching, so parameterised and array
/// forms (`varchar(255)`, `int[]`) map like their base type. `interval` and `point`
/// are checked before the generic `int` rule because both contain `"int"`. Serial
/// pseudo-types map to their underlying integer width. Unknown types become `"string"`.
fn mermaid_type(pg_type: &str) -> &'static str {
    let lower = pg_type.to_lowercase();
    if lower.contains("interval") {
        "interval"
    } else if lower.contains("point") {
        "point"
    } else if lower.contains("bigint")
        || lower.contains("int8")
        || lower.contains("bigserial")
        || lower.contains("serial8")
    {
        "bigint"
    } else if lower.contains("int") || lower.contains("serial") {
        "int"
    } else if lower.contains("bool") {
        "boolean"
    } else if lower.contains("text") || lower.contains("varchar") || lower.contains("char") {
        "string"
    } else if lower.contains("timestamp") || lower.contains("date") {
        "datetime"
    } else if lower.contains("uuid") {
        "uuid"
    } else if lower.contains("json") {
        "json"
    } else if lower.contains("float")
        || lower.contains("real")
        || lower.contains("double")
        || lower.contains("numeric")
        || lower.contains("decimal")
    {
        "float"
    } else if lower.contains("bytea") {
        "bytes"
    } else {
        "string"
    }
}
534
/// Human-readable row count (e.g. 1_400_000 → "1.4M").
///
/// Counts below 1,000 are printed exactly. Larger counts are scaled to K, M or B with
/// one decimal place. If rounding would print `1000.0` of a unit (e.g. 999_999 would
/// become "1000.0K"), the next larger unit is used instead ("1.0M"). Billions are the
/// largest unit, so very large counts may print as "1000.0B" or more.
fn human_rows(n: u64) -> String {
    const UNITS: [(u64, &str); 3] = [(1_000, "K"), (1_000_000, "M"), (1_000_000_000, "B")];

    // Largest unit not exceeding n; below 1K there is no unit.
    let Some(mut unit) = UNITS.iter().rposition(|&(divisor, _)| n >= divisor) else {
        return n.to_string();
    };
    let scaled = |i: usize| format!("{:.1}", n as f64 / UNITS[i].0 as f64);

    let mut text = scaled(unit);
    // n / divisor < 1000 by construction, so "1000.0" can only arise from rounding.
    if text == "1000.0" && unit + 1 < UNITS.len() {
        unit += 1;
        text = scaled(unit);
    }
    format!("{}{}", text, UNITS[unit].1)
}
547
548impl Default for SchemaGraph {
549    fn default() -> Self {
550        Self::new()
551    }
552}