1use crate::schema::{ColumnType, SchemaGraph};
4use ahash::{AHashMap, AHashSet};
5use glob::Pattern;
6use std::collections::VecDeque;
7
/// Render-ready description of a single table column.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ColumnInfo {
    /// Column name as it appears in the schema.
    pub name: String,
    /// Display label for the column type, as produced by `format_column_type`
    /// (e.g. `INT`, `VARCHAR`).
    pub col_type: String,
    /// Whether the column is (part of) the table's primary key.
    pub is_primary_key: bool,
    /// Whether the column participates in a foreign key.
    pub is_foreign_key: bool,
    /// Whether the column accepts `NULL`.
    pub is_nullable: bool,
    /// Table referenced by this column's foreign key, if any.
    pub references_table: Option<String>,
    /// Column referenced by this column's foreign key, if any.
    pub references_column: Option<String>,
}

/// A table together with its columns, in schema order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TableInfo {
    /// Table name.
    pub name: String,
    /// Columns of the table.
    pub columns: Vec<ColumnInfo>,
}
35
36#[derive(Debug, Clone)]
38pub struct EdgeInfo {
39 pub from_table: String,
41 pub from_column: String,
43 pub to_table: String,
45 pub to_column: String,
47 pub cardinality: Cardinality,
49}
50
/// Cardinality of a relationship between two tables.
///
/// The default is [`Cardinality::ManyToOne`], which is what plain foreign keys are
/// tagged with when a view is built from a schema.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Cardinality {
    /// Many rows on the left relate to one row on the right.
    #[default]
    ManyToOne,
    /// One row on each side.
    OneToOne,
    /// One row on the left relates to many rows on the right.
    OneToMany,
    /// Many rows on each side.
    ManyToMany,
}

impl Cardinality {
    /// Returns the Mermaid `erDiagram` relationship connector for this cardinality.
    pub fn as_mermaid(self) -> &'static str {
        match self {
            Self::ManyToOne => "}o--||",
            Self::OneToOne => "||--||",
            Self::OneToMany => "||--o{",
            Self::ManyToMany => "}o--o{",
        }
    }
}
72
73#[derive(Debug)]
75pub struct GraphView {
76 pub tables: AHashMap<String, TableInfo>,
78 pub edges: Vec<EdgeInfo>,
80}
81
82impl GraphView {
83 pub fn from_schema_graph(graph: &SchemaGraph) -> Self {
85 let mut tables = AHashMap::new();
86 let mut edges = Vec::new();
87
88 let mut fk_lookup: AHashMap<(String, String), (String, String)> = AHashMap::new();
90
91 for table_schema in graph.schema.iter() {
92 for fk in &table_schema.foreign_keys {
93 for (i, col_name) in fk.column_names.iter().enumerate() {
94 let ref_col = fk.referenced_columns.get(i).cloned().unwrap_or_default();
95 fk_lookup.insert(
96 (table_schema.name.clone(), col_name.clone()),
97 (fk.referenced_table.clone(), ref_col),
98 );
99 }
100 }
101 }
102
103 for table_schema in graph.schema.iter() {
105 let mut columns = Vec::new();
106
107 for col in &table_schema.columns {
108 let is_fk = fk_lookup.contains_key(&(table_schema.name.clone(), col.name.clone()));
109 let (ref_table, ref_col) = fk_lookup
110 .get(&(table_schema.name.clone(), col.name.clone()))
111 .cloned()
112 .map(|(t, c)| (Some(t), Some(c)))
113 .unwrap_or((None, None));
114
115 columns.push(ColumnInfo {
116 name: col.name.clone(),
117 col_type: format_column_type(&col.col_type),
118 is_primary_key: col.is_primary_key,
119 is_foreign_key: is_fk,
120 is_nullable: col.is_nullable,
121 references_table: ref_table,
122 references_column: ref_col,
123 });
124 }
125
126 tables.insert(
127 table_schema.name.clone(),
128 TableInfo {
129 name: table_schema.name.clone(),
130 columns,
131 },
132 );
133 }
134
135 for table_schema in graph.schema.iter() {
137 for fk in &table_schema.foreign_keys {
138 for (i, col_name) in fk.column_names.iter().enumerate() {
140 let ref_col = fk
141 .referenced_columns
142 .get(i)
143 .cloned()
144 .unwrap_or_else(|| "id".to_string());
145
146 edges.push(EdgeInfo {
147 from_table: table_schema.name.clone(),
148 from_column: col_name.clone(),
149 to_table: fk.referenced_table.clone(),
150 to_column: ref_col,
151 cardinality: Cardinality::ManyToOne,
152 });
153 }
154 }
155 }
156
157 Self { tables, edges }
158 }
159
160 pub fn filter_tables(&mut self, patterns: &[Pattern]) {
162 if patterns.is_empty() {
163 return;
164 }
165
166 let matching: AHashSet<String> = self
167 .tables
168 .keys()
169 .filter(|name| patterns.iter().any(|p| p.matches(name)))
170 .cloned()
171 .collect();
172
173 self.apply_node_filter(&matching);
174 }
175
176 pub fn exclude_tables(&mut self, patterns: &[Pattern]) {
178 if patterns.is_empty() {
179 return;
180 }
181
182 let remaining: AHashSet<String> = self
183 .tables
184 .keys()
185 .filter(|name| !patterns.iter().any(|p| p.matches(name)))
186 .cloned()
187 .collect();
188
189 self.apply_node_filter(&remaining);
190 }
191
192 pub fn focus_table(
194 &mut self,
195 table: &str,
196 transitive: bool,
197 reverse: bool,
198 max_depth: Option<usize>,
199 ) {
200 if !self.tables.contains_key(table) {
201 self.tables.clear();
202 self.edges.clear();
203 return;
204 }
205
206 let mut result_nodes = AHashSet::new();
207 result_nodes.insert(table.to_string());
208
209 let (outgoing, incoming) = self.build_adjacency_maps();
211
212 if transitive {
213 self.traverse(&outgoing, table, max_depth, &mut result_nodes);
215 }
216
217 if reverse {
218 self.traverse(&incoming, table, max_depth, &mut result_nodes);
220 }
221
222 if !transitive && !reverse {
224 if let Some(parents) = outgoing.get(table) {
225 for parent in parents {
226 result_nodes.insert(parent.clone());
227 }
228 }
229 if let Some(children) = incoming.get(table) {
230 for child in children {
231 result_nodes.insert(child.clone());
232 }
233 }
234 }
235
236 self.apply_node_filter(&result_nodes);
237 }
238
239 pub fn filter_to_cyclic_tables(&mut self, cyclic_tables: &AHashSet<String>) {
241 self.apply_node_filter(cyclic_tables);
242 }
243
244 pub fn table_count(&self) -> usize {
246 self.tables.len()
247 }
248
249 pub fn edge_count(&self) -> usize {
251 self.edges.len()
252 }
253
254 pub fn is_empty(&self) -> bool {
256 self.tables.is_empty()
257 }
258
259 pub fn sorted_tables(&self) -> Vec<&TableInfo> {
261 let mut tables: Vec<_> = self.tables.values().collect();
262 tables.sort_by(|a, b| a.name.cmp(&b.name));
263 tables
264 }
265
266 pub fn get_table(&self, name: &str) -> Option<&TableInfo> {
268 self.tables.get(name)
269 }
270
271 fn apply_node_filter(&mut self, keep: &AHashSet<String>) {
274 self.tables.retain(|n, _| keep.contains(n));
275 self.edges
276 .retain(|e| keep.contains(&e.from_table) && keep.contains(&e.to_table));
277 }
278
279 fn build_adjacency_maps(
280 &self,
281 ) -> (AHashMap<String, Vec<String>>, AHashMap<String, Vec<String>>) {
282 let mut outgoing: AHashMap<String, Vec<String>> = AHashMap::new();
283 let mut incoming: AHashMap<String, Vec<String>> = AHashMap::new();
284
285 for edge in &self.edges {
286 outgoing
287 .entry(edge.from_table.clone())
288 .or_default()
289 .push(edge.to_table.clone());
290 incoming
291 .entry(edge.to_table.clone())
292 .or_default()
293 .push(edge.from_table.clone());
294 }
295
296 (outgoing, incoming)
297 }
298
299 fn traverse(
300 &self,
301 adjacency: &AHashMap<String, Vec<String>>,
302 start: &str,
303 max_depth: Option<usize>,
304 result: &mut AHashSet<String>,
305 ) {
306 let mut queue: VecDeque<(String, usize)> = VecDeque::new();
307 queue.push_back((start.to_string(), 0));
308
309 while let Some((current, depth)) = queue.pop_front() {
310 if let Some(max) = max_depth {
311 if depth >= max {
312 continue;
313 }
314 }
315
316 if let Some(neighbors) = adjacency.get(¤t) {
317 for neighbor in neighbors {
318 if result.insert(neighbor.clone()) {
319 queue.push_back((neighbor.clone(), depth + 1));
320 }
321 }
322 }
323 }
324 }
325}
326
327fn format_column_type(col_type: &ColumnType) -> String {
329 match col_type {
330 ColumnType::Int => "INT".to_string(),
331 ColumnType::BigInt => "BIGINT".to_string(),
332 ColumnType::Text => "VARCHAR".to_string(),
333 ColumnType::Uuid => "UUID".to_string(),
334 ColumnType::Decimal => "DECIMAL".to_string(),
335 ColumnType::DateTime => "DATETIME".to_string(),
336 ColumnType::Bool => "BOOL".to_string(),
337 ColumnType::Other(s) => s.to_uppercase(),
338 }
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-table fixture with a single edge: `orders.user_id -> users.id`.
    fn create_test_view() -> GraphView {
        let mut tables = AHashMap::new();

        tables.insert(
            "users".to_string(),
            TableInfo {
                name: "users".to_string(),
                columns: vec![
                    ColumnInfo {
                        name: "id".to_string(),
                        col_type: "INT".to_string(),
                        is_primary_key: true,
                        is_foreign_key: false,
                        is_nullable: false,
                        references_table: None,
                        references_column: None,
                    },
                    ColumnInfo {
                        name: "email".to_string(),
                        col_type: "VARCHAR".to_string(),
                        is_primary_key: false,
                        is_foreign_key: false,
                        is_nullable: false,
                        references_table: None,
                        references_column: None,
                    },
                ],
            },
        );

        tables.insert(
            "orders".to_string(),
            TableInfo {
                name: "orders".to_string(),
                columns: vec![
                    ColumnInfo {
                        name: "id".to_string(),
                        col_type: "INT".to_string(),
                        is_primary_key: true,
                        is_foreign_key: false,
                        is_nullable: false,
                        references_table: None,
                        references_column: None,
                    },
                    ColumnInfo {
                        name: "user_id".to_string(),
                        col_type: "INT".to_string(),
                        is_primary_key: false,
                        is_foreign_key: true,
                        is_nullable: false,
                        references_table: Some("users".to_string()),
                        references_column: Some("id".to_string()),
                    },
                ],
            },
        );

        let edges = vec![EdgeInfo {
            from_table: "orders".to_string(),
            from_column: "user_id".to_string(),
            to_table: "users".to_string(),
            to_column: "id".to_string(),
            cardinality: Cardinality::ManyToOne,
        }];

        GraphView { tables, edges }
    }

    #[test]
    fn test_table_info() {
        let view = create_test_view();
        assert_eq!(view.table_count(), 2);

        let users = view.get_table("users").unwrap();
        assert_eq!(users.columns.len(), 2);
        assert!(users.columns[0].is_primary_key);
    }

    #[test]
    fn test_edge_info() {
        let view = create_test_view();
        assert_eq!(view.edge_count(), 1);

        let edge = &view.edges[0];
        assert_eq!(edge.from_table, "orders");
        assert_eq!(edge.from_column, "user_id");
        assert_eq!(edge.to_table, "users");
        assert_eq!(edge.to_column, "id");
    }

    #[test]
    fn test_exclude_tables() {
        let mut view = create_test_view();
        let patterns = vec![Pattern::new("orders").unwrap()];
        view.exclude_tables(&patterns);

        assert!(!view.tables.contains_key("orders"));
        assert!(view.tables.contains_key("users"));
        assert_eq!(view.edge_count(), 0);
    }

    #[test]
    fn test_filter_tables() {
        let mut view = create_test_view();
        view.filter_tables(&[Pattern::new("user*").unwrap()]);

        assert!(view.tables.contains_key("users"));
        assert!(!view.tables.contains_key("orders"));
        assert_eq!(view.edge_count(), 0);
    }

    #[test]
    fn test_empty_patterns_are_noop() {
        let mut view = create_test_view();
        view.filter_tables(&[]);
        view.exclude_tables(&[]);

        assert_eq!(view.table_count(), 2);
        assert_eq!(view.edge_count(), 1);
    }

    #[test]
    fn test_focus_table_direct_neighbours() {
        let mut view = create_test_view();
        view.focus_table("users", false, false, None);

        assert_eq!(view.table_count(), 2);
        assert_eq!(view.edge_count(), 1);
    }

    #[test]
    fn test_focus_table_transitive_follows_outgoing_only() {
        let mut view = create_test_view();
        // `users` references nothing, so only `users` itself remains.
        view.focus_table("users", true, false, None);

        assert_eq!(view.table_count(), 1);
        assert!(view.tables.contains_key("users"));
        assert_eq!(view.edge_count(), 0);
    }

    #[test]
    fn test_focus_table_reverse_follows_incoming() {
        let mut view = create_test_view();
        view.focus_table("users", false, true, None);

        assert_eq!(view.table_count(), 2);
        assert_eq!(view.edge_count(), 1);
    }

    #[test]
    fn test_focus_table_max_depth_zero_keeps_only_focus() {
        let mut view = create_test_view();
        view.focus_table("orders", true, false, Some(0));

        assert_eq!(view.table_count(), 1);
        assert!(view.tables.contains_key("orders"));
        assert_eq!(view.edge_count(), 0);
    }

    #[test]
    fn test_focus_unknown_table_clears_view() {
        let mut view = create_test_view();
        view.focus_table("missing", true, true, None);

        assert!(view.is_empty());
        assert_eq!(view.edge_count(), 0);
    }

    #[test]
    fn test_sorted_tables() {
        let view = create_test_view();
        let names: Vec<&str> = view
            .sorted_tables()
            .into_iter()
            .map(|t| t.name.as_str())
            .collect();

        assert_eq!(names, ["orders", "users"]);
    }

    #[test]
    fn test_cardinality_mermaid() {
        assert_eq!(Cardinality::default().as_mermaid(), "}o--||");
        assert_eq!(Cardinality::OneToOne.as_mermaid(), "||--||");
        assert_eq!(Cardinality::OneToMany.as_mermaid(), "||--o{");
        assert_eq!(Cardinality::ManyToMany.as_mermaid(), "}o--o{");
    }
}