Skip to main content

sql_splitter/schema/
graph.rs

1//! Schema dependency graph for FK-aware operations.
2//!
3//! Provides:
4//! - Dependency graph construction from schema FK relationships
5//! - Topological sorting for processing order
6//! - Cycle detection for handling circular FK relationships
7
8use super::{Schema, TableId};
9use std::collections::VecDeque;
10
11/// Schema dependency graph built from foreign key relationships.
12///
13/// The graph represents parent → child relationships where:
14/// - A parent is a table referenced by another table's FK
15/// - A child is a table that has an FK referencing another table
16///
17/// This ordering allows processing parents before children, ensuring
18/// that when sampling/filtering children, parent data is already available.
19#[derive(Debug)]
20pub struct SchemaGraph {
21    /// The underlying schema
22    pub schema: Schema,
23    /// For each table, list of parent tables (tables this table references via FK)
24    pub parents: Vec<Vec<TableId>>,
25    /// For each table, list of child tables (tables that reference this table via FK)
26    pub children: Vec<Vec<TableId>>,
27}
28
29/// Result of topological sort
30#[derive(Debug)]
31pub struct TopoSortResult {
32    /// Tables in topological order (parents before children)
33    pub order: Vec<TableId>,
34    /// Tables that are part of cycles (could not be ordered)
35    pub cyclic_tables: Vec<TableId>,
36}
37
38impl SchemaGraph {
39    /// Build a dependency graph from a schema
40    pub fn from_schema(schema: Schema) -> Self {
41        let n = schema.table_schemas.len();
42        let mut parents: Vec<Vec<TableId>> = vec![Vec::new(); n];
43        let mut children: Vec<Vec<TableId>> = vec![Vec::new(); n];
44
45        for table in &schema.table_schemas {
46            let child_id = table.id;
47
48            for fk in &table.foreign_keys {
49                if let Some(parent_id) = fk.referenced_table_id {
50                    // Avoid self-references in the graph (handle separately)
51                    if parent_id != child_id {
52                        // Child depends on parent
53                        if !parents[child_id.0 as usize].contains(&parent_id) {
54                            parents[child_id.0 as usize].push(parent_id);
55                        }
56                        // Parent has child dependent
57                        if !children[parent_id.0 as usize].contains(&child_id) {
58                            children[parent_id.0 as usize].push(child_id);
59                        }
60                    }
61                }
62            }
63        }
64
65        Self {
66            schema,
67            parents,
68            children,
69        }
70    }
71
72    /// Get the number of tables in the graph
73    pub fn len(&self) -> usize {
74        self.schema.len()
75    }
76
77    /// Check if the graph is empty
78    pub fn is_empty(&self) -> bool {
79        self.schema.is_empty()
80    }
81
82    /// Get the table name for a table ID
83    pub fn table_name(&self, id: TableId) -> Option<&str> {
84        self.schema.table(id).map(|t| t.name.as_str())
85    }
86
87    /// Check if a table has a self-referential FK
88    pub fn has_self_reference(&self, id: TableId) -> bool {
89        self.schema
90            .table(id)
91            .map(|t| {
92                t.foreign_keys
93                    .iter()
94                    .any(|fk| fk.referenced_table_id == Some(id))
95            })
96            .unwrap_or(false)
97    }
98
99    /// Get tables that have self-referential FKs
100    pub fn self_referential_tables(&self) -> Vec<TableId> {
101        (0..self.len())
102            .map(|i| TableId(i as u32))
103            .filter(|&id| self.has_self_reference(id))
104            .collect()
105    }
106
107    /// Perform topological sort using Kahn's algorithm.
108    ///
109    /// Returns tables in dependency order (parents before children).
110    /// Tables that are part of cycles are returned separately.
111    pub fn topo_sort(&self) -> TopoSortResult {
112        let n = self.len();
113        if n == 0 {
114            return TopoSortResult {
115                order: Vec::new(),
116                cyclic_tables: Vec::new(),
117            };
118        }
119
120        // Calculate in-degrees (number of parents for each table)
121        let mut in_degree: Vec<usize> = vec![0; n];
122        for (i, parents) in self.parents.iter().enumerate() {
123            in_degree[i] = parents.len();
124        }
125
126        // Start with tables that have no parents (roots)
127        let mut queue: VecDeque<TableId> = VecDeque::new();
128        for (i, &deg) in in_degree.iter().enumerate() {
129            if deg == 0 {
130                queue.push_back(TableId(i as u32));
131            }
132        }
133
134        let mut order = Vec::with_capacity(n);
135
136        while let Some(table_id) = queue.pop_front() {
137            order.push(table_id);
138
139            // Reduce in-degree of all children
140            for &child_id in &self.children[table_id.0 as usize] {
141                in_degree[child_id.0 as usize] -= 1;
142                if in_degree[child_id.0 as usize] == 0 {
143                    queue.push_back(child_id);
144                }
145            }
146        }
147
148        // Tables with remaining in-degree > 0 are part of cycles
149        let cyclic_tables: Vec<TableId> = in_degree
150            .iter()
151            .enumerate()
152            .filter(|(_, &deg)| deg > 0)
153            .map(|(i, _)| TableId(i as u32))
154            .collect();
155
156        TopoSortResult {
157            order,
158            cyclic_tables,
159        }
160    }
161
162    /// Get processing order for sampling/sharding.
163    ///
164    /// Returns all tables in order: first the topologically sorted acyclic tables,
165    /// then the cyclic tables (which need special handling).
166    pub fn processing_order(&self) -> (Vec<TableId>, Vec<TableId>) {
167        let result = self.topo_sort();
168        (result.order, result.cyclic_tables)
169    }
170
171    /// Check if table A is an ancestor of table B (A is referenced by B directly or transitively)
172    pub fn is_ancestor(&self, ancestor: TableId, descendant: TableId) -> bool {
173        if ancestor == descendant {
174            return false;
175        }
176
177        let mut visited = vec![false; self.len()];
178        let mut queue = VecDeque::new();
179        queue.push_back(descendant);
180
181        while let Some(current) = queue.pop_front() {
182            for &parent in &self.parents[current.0 as usize] {
183                if parent == ancestor {
184                    return true;
185                }
186                if !visited[parent.0 as usize] {
187                    visited[parent.0 as usize] = true;
188                    queue.push_back(parent);
189                }
190            }
191        }
192
193        false
194    }
195
196    /// Get all ancestor tables of a given table (tables it depends on, directly or transitively)
197    pub fn ancestors(&self, id: TableId) -> Vec<TableId> {
198        let mut ancestors = Vec::new();
199        let mut visited = vec![false; self.len()];
200        let mut queue = VecDeque::new();
201
202        for &parent in &self.parents[id.0 as usize] {
203            queue.push_back(parent);
204            visited[parent.0 as usize] = true;
205        }
206
207        while let Some(current) = queue.pop_front() {
208            ancestors.push(current);
209            for &parent in &self.parents[current.0 as usize] {
210                if !visited[parent.0 as usize] {
211                    visited[parent.0 as usize] = true;
212                    queue.push_back(parent);
213                }
214            }
215        }
216
217        ancestors
218    }
219
220    /// Get all descendant tables of a given table (tables that depend on it)
221    pub fn descendants(&self, id: TableId) -> Vec<TableId> {
222        let mut descendants = Vec::new();
223        let mut visited = vec![false; self.len()];
224        let mut queue = VecDeque::new();
225
226        for &child in &self.children[id.0 as usize] {
227            queue.push_back(child);
228            visited[child.0 as usize] = true;
229        }
230
231        while let Some(current) = queue.pop_front() {
232            descendants.push(current);
233            for &child in &self.children[current.0 as usize] {
234                if !visited[child.0 as usize] {
235                    visited[child.0 as usize] = true;
236                    queue.push_back(child);
237                }
238            }
239        }
240
241        descendants
242    }
243
244    /// Get root tables (tables with no parents/dependencies)
245    pub fn root_tables(&self) -> Vec<TableId> {
246        self.parents
247            .iter()
248            .enumerate()
249            .filter(|(_, parents)| parents.is_empty())
250            .map(|(i, _)| TableId(i as u32))
251            .collect()
252    }
253
254    /// Get leaf tables (tables with no children/dependents)
255    pub fn leaf_tables(&self) -> Vec<TableId> {
256        self.children
257            .iter()
258            .enumerate()
259            .filter(|(_, children)| children.is_empty())
260            .map(|(i, _)| TableId(i as u32))
261            .collect()
262    }
263}