// sql_splitter/schema/mod.rs

//! Schema analysis module for FK-aware operations.
//!
//! This module provides:
//! - Data models for table schemas, columns, and foreign keys
//! - MySQL DDL parsing for extracting schema information
//! - Dependency graph construction with topological sorting
//! - Cycle detection for handling circular FK relationships
8
9mod ddl;
10mod graph;
11
12pub use ddl::*;
13pub use graph::*;
14
15use ahash::AHashMap;
16use std::fmt;
17
/// Identifier of a table inside a schema.
///
/// `Schema::add_table` hands these out sequentially, starting at 0, so the
/// wrapped number is also the table's position in `Schema::table_schemas`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct TableId(pub u32);

impl fmt::Display for TableId {
    /// Writes the id as `TableId(<n>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let TableId(raw) = *self;
        write!(f, "TableId({})", raw)
    }
}
27
/// Identifier of a column inside a single table.
///
/// The wrapped number is the column's zero-based position in its table.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct ColumnId(pub u16);

impl fmt::Display for ColumnId {
    /// Writes the id as `ColumnId(<n>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ColumnId(raw) = *self;
        write!(f, "ColumnId({})", raw)
    }
}
37
/// Coarse classification of a SQL column type.
///
/// Values are produced by [`ColumnType::from_sql_type`], which accepts MySQL,
/// PostgreSQL and SQLite spellings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ColumnType {
    /// Integer types: INT, INTEGER, TINYINT, SMALLINT, MEDIUMINT, and the
    /// PostgreSQL forms INT2, INT4, SERIAL, SMALLSERIAL, SERIAL2, SERIAL4
    Int,
    /// Big integer types: BIGINT, INT8, BIGSERIAL, SERIAL8
    BigInt,
    /// Text types: CHAR, VARCHAR, NCHAR, NVARCHAR, CHARACTER, the TEXT
    /// variants, ENUM and SET
    Text,
    /// UUID values, e.g. a native UUID column or a 16-byte binary column
    Uuid,
    /// Decimal/numeric and floating-point types
    Decimal,
    /// Date/time types
    DateTime,
    /// Boolean type
    Bool,
    /// Any other type; carries the original type string unchanged
    Other(String),
}

impl ColumnType {
    /// Parse a SQL type string into a [`ColumnType`].
    ///
    /// Supports MySQL, PostgreSQL and SQLite types. Matching is
    /// case-insensitive and only looks at the first word of the type name.
    /// Length/precision arguments and trailing modifiers are therefore
    /// ignored:
    ///
    /// - `INT(11) UNSIGNED` → [`ColumnType::Int`]
    /// - `character varying(255)` → [`ColumnType::Text`]
    /// - `double precision` → [`ColumnType::Decimal`]
    /// - `timestamp with time zone` → [`ColumnType::DateTime`]
    ///
    /// Binary types (`BINARY`, `VARBINARY`, `BLOB`, `BYTEA`) become
    /// [`ColumnType::Uuid`] only when declared with a length of exactly 16.
    /// Anything unrecognised yields [`ColumnType::Other`] holding `type_str`
    /// unchanged.
    pub fn from_sql_type(type_str: &str) -> Self {
        let type_lower = type_str.to_lowercase();
        // Base type = first word before any `(`. This strips length arguments
        // ("varchar(255)") as well as modifiers ("int unsigned") and the tail
        // of multi-word names ("double precision", "character varying").
        let base_type = type_lower
            .split('(')
            .next()
            .and_then(|head| head.split_whitespace().next())
            .unwrap_or("");

        match base_type {
            // Integer types (all dialects)
            "int" | "integer" | "tinyint" | "smallint" | "mediumint" | "int4" | "int2" => {
                ColumnType::Int
            }
            // Auto-increment integer types (PostgreSQL)
            "serial" | "smallserial" | "serial2" | "serial4" => ColumnType::Int,
            "bigint" | "int8" | "bigserial" | "serial8" => ColumnType::BigInt,
            // Text types (all dialects)
            "char" | "varchar" | "nchar" | "nvarchar" | "text" | "tinytext" | "mediumtext"
            | "longtext" | "enum" | "set" | "character" => ColumnType::Text,
            // Decimal types (all dialects)
            "decimal" | "numeric" | "float" | "double" | "real" | "float4" | "float8" | "money" => {
                ColumnType::Decimal
            }
            // Date/time types (all dialects)
            "date" | "datetime" | "timestamp" | "time" | "year" | "timestamptz" | "timetz"
            | "interval" => ColumnType::DateTime,
            // Boolean (all dialects)
            "bool" | "boolean" => ColumnType::Bool,
            // Binary types: a 16-byte binary column is treated as a packed UUID.
            // The declared length must be exactly 16 (a substring check would
            // also accept e.g. `varbinary(160)`).
            "binary" | "varbinary" | "blob" | "bytea" => {
                if Self::declared_length(&type_lower) == Some("16") {
                    ColumnType::Uuid
                } else {
                    ColumnType::Other(type_str.to_string())
                }
            }
            "uuid" => ColumnType::Uuid,
            _ => ColumnType::Other(type_str.to_string()),
        }
    }

    /// Parse a MySQL type string into a ColumnType (alias for [`ColumnType::from_sql_type`]).
    pub fn from_mysql_type(type_str: &str) -> Self {
        Self::from_sql_type(type_str)
    }

    /// Return the trimmed text between the first `(` and the next `)`.
    ///
    /// For example, this yields `"16"` for `"binary( 16 )"`. Returns `None`
    /// when the type has no parenthesised argument.
    fn declared_length(type_lower: &str) -> Option<&str> {
        let (_, rest) = type_lower.split_once('(')?;
        let (length, _) = rest.split_once(')')?;
        Some(length.trim())
    }
}
105
106/// Column definition within a table
107#[derive(Debug, Clone)]
108pub struct Column {
109    /// Column name
110    pub name: String,
111    /// Column type
112    pub col_type: ColumnType,
113    /// Position in table (0-indexed)
114    pub ordinal: ColumnId,
115    /// Whether this column is part of the primary key
116    pub is_primary_key: bool,
117    /// Whether this column allows NULL values
118    pub is_nullable: bool,
119}
120
121/// Foreign key constraint definition
122#[derive(Debug, Clone)]
123pub struct ForeignKey {
124    /// Constraint name (optional)
125    pub name: Option<String>,
126    /// Column IDs in this table that form the FK
127    pub columns: Vec<ColumnId>,
128    /// Column names in this table (before resolution)
129    pub column_names: Vec<String>,
130    /// Referenced table name
131    pub referenced_table: String,
132    /// Referenced column names
133    pub referenced_columns: Vec<String>,
134    /// Resolved referenced table ID (set after schema is complete)
135    pub referenced_table_id: Option<TableId>,
136}
137
138/// Complete table schema definition
139#[derive(Debug, Clone)]
140pub struct TableSchema {
141    /// Table name
142    pub name: String,
143    /// Table ID within the schema
144    pub id: TableId,
145    /// Column definitions in order
146    pub columns: Vec<Column>,
147    /// Primary key column IDs (ordered for composite PKs)
148    pub primary_key: Vec<ColumnId>,
149    /// Foreign key constraints
150    pub foreign_keys: Vec<ForeignKey>,
151    /// Raw CREATE TABLE statement (for output)
152    pub create_statement: Option<String>,
153}
154
155impl TableSchema {
156    /// Create a new empty table schema
157    pub fn new(name: String, id: TableId) -> Self {
158        Self {
159            name,
160            id,
161            columns: Vec::new(),
162            primary_key: Vec::new(),
163            foreign_keys: Vec::new(),
164            create_statement: None,
165        }
166    }
167
168    /// Get a column by name
169    pub fn get_column(&self, name: &str) -> Option<&Column> {
170        self.columns
171            .iter()
172            .find(|c| c.name.eq_ignore_ascii_case(name))
173    }
174
175    /// Get column ID by name
176    pub fn get_column_id(&self, name: &str) -> Option<ColumnId> {
177        self.get_column(name).map(|c| c.ordinal)
178    }
179
180    /// Get column by ID
181    pub fn column(&self, id: ColumnId) -> Option<&Column> {
182        self.columns.get(id.0 as usize)
183    }
184
185    /// Check if column is part of the primary key
186    pub fn is_pk_column(&self, col_id: ColumnId) -> bool {
187        self.primary_key.contains(&col_id)
188    }
189
190    /// Get all FK column IDs (columns that reference other tables)
191    pub fn fk_column_ids(&self) -> Vec<ColumnId> {
192        self.foreign_keys
193            .iter()
194            .flat_map(|fk| fk.columns.iter().copied())
195            .collect()
196    }
197}
198
199/// Complete database schema
200#[derive(Debug)]
201pub struct Schema {
202    /// Map from table name to table ID
203    pub tables: AHashMap<String, TableId>,
204    /// Table schemas indexed by TableId
205    pub table_schemas: Vec<TableSchema>,
206}
207
208impl Schema {
209    /// Create a new empty schema
210    pub fn new() -> Self {
211        Self {
212            tables: AHashMap::new(),
213            table_schemas: Vec::new(),
214        }
215    }
216
217    /// Get table ID by name (case-insensitive)
218    pub fn get_table_id(&self, name: &str) -> Option<TableId> {
219        // Try exact match first
220        if let Some(&id) = self.tables.get(name) {
221            return Some(id);
222        }
223        // Try case-insensitive match
224        let name_lower = name.to_lowercase();
225        self.tables
226            .iter()
227            .find(|(k, _)| k.to_lowercase() == name_lower)
228            .map(|(_, &id)| id)
229    }
230
231    /// Get table schema by ID
232    pub fn table(&self, id: TableId) -> Option<&TableSchema> {
233        self.table_schemas.get(id.0 as usize)
234    }
235
236    /// Get mutable table schema by ID
237    pub fn table_mut(&mut self, id: TableId) -> Option<&mut TableSchema> {
238        self.table_schemas.get_mut(id.0 as usize)
239    }
240
241    /// Get table schema by name
242    pub fn get_table(&self, name: &str) -> Option<&TableSchema> {
243        self.get_table_id(name).and_then(|id| self.table(id))
244    }
245
246    /// Add a new table schema, returning its ID
247    pub fn add_table(&mut self, mut schema: TableSchema) -> TableId {
248        let id = TableId(self.table_schemas.len() as u32);
249        schema.id = id;
250        self.tables.insert(schema.name.clone(), id);
251        self.table_schemas.push(schema);
252        id
253    }
254
255    /// Resolve all foreign key references to table IDs
256    pub fn resolve_foreign_keys(&mut self) {
257        let table_ids: AHashMap<String, TableId> = self.tables.clone();
258
259        for table in &mut self.table_schemas {
260            for fk in &mut table.foreign_keys {
261                fk.referenced_table_id = table_ids
262                    .get(&fk.referenced_table)
263                    .or_else(|| {
264                        // Case-insensitive fallback
265                        let lower = fk.referenced_table.to_lowercase();
266                        table_ids
267                            .iter()
268                            .find(|(k, _)| k.to_lowercase() == lower)
269                            .map(|(_, v)| v)
270                    })
271                    .copied();
272            }
273        }
274    }
275
276    /// Get the number of tables
277    pub fn len(&self) -> usize {
278        self.table_schemas.len()
279    }
280
281    /// Check if schema is empty
282    pub fn is_empty(&self) -> bool {
283        self.table_schemas.is_empty()
284    }
285
286    /// Iterate over all table schemas
287    pub fn iter(&self) -> impl Iterator<Item = &TableSchema> {
288        self.table_schemas.iter()
289    }
290}
291
292impl Default for Schema {
293    fn default() -> Self {
294        Self::new()
295    }
296}