Skip to main content

sql_splitter/schema/
mod.rs

1//! Schema analysis module for FK-aware operations.
2//!
3//! This module provides:
4//! - Data models for table schemas, columns, and foreign keys
5//! - MySQL DDL parsing for extracting schema information
6//! - Dependency graph construction with topological sorting
7//! - Cycle detection for handling circular FK relationships
8
9mod ddl;
10mod graph;
11
12pub use ddl::*;
13pub use graph::*;
14
15use ahash::AHashMap;
16use std::fmt;
17
/// Identifies a table within a schema.
///
/// `Schema::table` uses the wrapped number as an index into the schema's
/// table list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TableId(pub u32);

impl fmt::Display for TableId {
    /// Renders the ID as `TableId(<n>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let TableId(raw) = *self;
        f.write_fmt(format_args!("TableId({})", raw))
    }
}
27
/// Identifies a column within its table.
///
/// `TableSchema::column` uses the wrapped number as an index into the
/// table's column list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ColumnId(pub u16);

impl fmt::Display for ColumnId {
    /// Renders the ID as `ColumnId(<n>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ColumnId(raw) = *self;
        f.write_fmt(format_args!("ColumnId({})", raw))
    }
}
37
/// Coarse classification of a SQL column type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ColumnType {
    /// Integer types: INT, INTEGER, TINYINT, SMALLINT, MEDIUMINT, INT2, INT4,
    /// SERIAL, SMALLSERIAL
    Int,
    /// Big integer types: BIGINT, INT8, BIGSERIAL
    BigInt,
    /// Text types: CHAR, CHARACTER, VARCHAR, the TEXT family, ENUM, SET
    Text,
    /// UUID types: a native `uuid` column, or a binary column declared with a
    /// length of exactly 16. (Code outside this type may also assign this
    /// variant, e.g. based on the column name.)
    Uuid,
    /// Decimal/numeric types, including floating-point and MONEY
    Decimal,
    /// Date/time types
    DateTime,
    /// Boolean type
    Bool,
    /// Any other type; carries the type string exactly as it was passed in
    Other(String),
}

impl ColumnType {
    /// Parse a SQL type string into a `ColumnType`.
    ///
    /// Supports MySQL, PostgreSQL, and SQLite type names. Matching is
    /// case-insensitive and looks only at the text before the first `(`, so
    /// `VARCHAR(255)` and `varchar` classify the same way.
    ///
    /// Binary types (`binary`, `varbinary`, `blob`, `bytea`) are classified as
    /// [`ColumnType::Uuid`] only when their declared length is exactly 16;
    /// every other binary type, and every unrecognised type, is returned as
    /// [`ColumnType::Other`] holding the original `type_str`.
    pub fn from_sql_type(type_str: &str) -> Self {
        let type_lower = type_str.to_lowercase();
        let base_type = type_lower.split('(').next().unwrap_or(&type_lower).trim();

        match base_type {
            // Integer types (all dialects)
            "int" | "integer" | "tinyint" | "smallint" | "mediumint" | "int4" | "int2" => {
                ColumnType::Int
            }
            // Auto-increment integer types (PostgreSQL)
            "serial" | "smallserial" => ColumnType::Int,
            "bigint" | "int8" | "bigserial" => ColumnType::BigInt,
            // Text types (all dialects)
            "char" | "varchar" | "text" | "tinytext" | "mediumtext" | "longtext" | "enum"
            | "set" | "character" => ColumnType::Text,
            // Decimal types (all dialects)
            "decimal" | "numeric" | "float" | "double" | "real" | "float4" | "float8" | "money" => {
                ColumnType::Decimal
            }
            // Date/time types (all dialects)
            "date" | "datetime" | "timestamp" | "time" | "year" | "timestamptz" | "timetz"
            | "interval" => ColumnType::DateTime,
            // Boolean (all dialects)
            "bool" | "boolean" => ColumnType::Bool,
            // Binary types
            "binary" | "varbinary" | "blob" | "bytea" => {
                // A 16-byte binary column is treated as a packed UUID. Compare the
                // parsed length rather than searching for the substring "16", which
                // would also match lengths such as 160 or 116.
                if Self::declared_length(&type_lower) == Some(16) {
                    ColumnType::Uuid
                } else {
                    ColumnType::Other(type_str.to_string())
                }
            }
            "uuid" => ColumnType::Uuid,
            _ => ColumnType::Other(type_str.to_string()),
        }
    }

    /// Parse a MySQL type string into a ColumnType (alias for from_sql_type)
    pub fn from_mysql_type(type_str: &str) -> Self {
        Self::from_sql_type(type_str)
    }

    /// Extract the single numeric length from a type such as `binary(16)`.
    ///
    /// Returns `None` when there is no parenthesised argument or when the
    /// argument is not a plain non-negative integer.
    fn declared_length(type_lower: &str) -> Option<u64> {
        let (_, after_open) = type_lower.split_once('(')?;
        let (argument, _) = after_open.split_once(')')?;
        argument.trim().parse().ok()
    }
}
105
/// Column definition within a table.
///
/// Instances are stored in `TableSchema::columns`.
#[derive(Debug, Clone)]
pub struct Column {
    /// Column name
    pub name: String,
    /// Classified SQL type of the column
    pub col_type: ColumnType,
    /// Position in table (0-indexed). `TableSchema::column` uses this ID as an
    /// index into the table's column list.
    pub ordinal: ColumnId,
    /// Whether this column is part of the primary key
    pub is_primary_key: bool,
    /// Whether this column allows NULL values
    pub is_nullable: bool,
}
120
/// Index definition.
///
/// Columns are referenced by name rather than by `ColumnId`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IndexDef {
    /// Index name
    pub name: String,
    /// Names of the columns in the index
    pub columns: Vec<String>,
    /// Whether this is a unique index
    pub is_unique: bool,
    /// Index type (BTREE, HASH, GIN, etc.); `None` when no type is recorded
    pub index_type: Option<String>,
}
133
/// Foreign key constraint definition.
///
/// The referenced table is stored by name; `referenced_table_id` is filled in
/// later by `Schema::resolve_foreign_keys`.
#[derive(Debug, Clone)]
pub struct ForeignKey {
    /// Constraint name (optional)
    pub name: Option<String>,
    /// Column IDs in this table that form the FK
    pub columns: Vec<ColumnId>,
    /// Column names in this table (before resolution)
    pub column_names: Vec<String>,
    /// Referenced table name
    pub referenced_table: String,
    /// Referenced column names
    pub referenced_columns: Vec<String>,
    /// Resolved referenced table ID (set after schema is complete); stays
    /// `None` when the referenced table name is not found in the schema
    pub referenced_table_id: Option<TableId>,
}
150
/// Complete table schema definition: columns, keys, indexes, and optionally
/// the original `CREATE TABLE` text.
#[derive(Debug, Clone)]
pub struct TableSchema {
    /// Table name
    pub name: String,
    /// Table ID within the schema (overwritten by `Schema::add_table`)
    pub id: TableId,
    /// Column definitions in order; `TableSchema::column` indexes this list
    /// by `ColumnId`
    pub columns: Vec<Column>,
    /// Primary key column IDs (ordered for composite PKs)
    pub primary_key: Vec<ColumnId>,
    /// Foreign key constraints
    pub foreign_keys: Vec<ForeignKey>,
    /// Index definitions
    pub indexes: Vec<IndexDef>,
    /// Raw CREATE TABLE statement (for output); `None` until one is stored
    pub create_statement: Option<String>,
}
169
170impl TableSchema {
171    /// Create a new empty table schema
172    pub fn new(name: String, id: TableId) -> Self {
173        Self {
174            name,
175            id,
176            columns: Vec::new(),
177            primary_key: Vec::new(),
178            foreign_keys: Vec::new(),
179            indexes: Vec::new(),
180            create_statement: None,
181        }
182    }
183
184    /// Get a column by name
185    pub fn get_column(&self, name: &str) -> Option<&Column> {
186        self.columns
187            .iter()
188            .find(|c| c.name.eq_ignore_ascii_case(name))
189    }
190
191    /// Get column ID by name
192    pub fn get_column_id(&self, name: &str) -> Option<ColumnId> {
193        self.get_column(name).map(|c| c.ordinal)
194    }
195
196    /// Get column by ID
197    pub fn column(&self, id: ColumnId) -> Option<&Column> {
198        self.columns.get(id.0 as usize)
199    }
200
201    /// Check if column is part of the primary key
202    pub fn is_pk_column(&self, col_id: ColumnId) -> bool {
203        self.primary_key.contains(&col_id)
204    }
205
206    /// Get all FK column IDs (columns that reference other tables)
207    pub fn fk_column_ids(&self) -> Vec<ColumnId> {
208        self.foreign_keys
209            .iter()
210            .flat_map(|fk| fk.columns.iter().copied())
211            .collect()
212    }
213}
214
/// Complete database schema: a name lookup table plus the table definitions.
#[derive(Debug)]
pub struct Schema {
    /// Map from table name to table ID; keys are stored exactly as given to
    /// `Schema::add_table`
    pub tables: AHashMap<String, TableId>,
    /// Table schemas indexed by TableId (`add_table` assigns each ID from the
    /// list length at insertion time)
    pub table_schemas: Vec<TableSchema>,
}
223
224impl Schema {
225    /// Create a new empty schema
226    pub fn new() -> Self {
227        Self {
228            tables: AHashMap::new(),
229            table_schemas: Vec::new(),
230        }
231    }
232
233    /// Get table ID by name (case-insensitive)
234    pub fn get_table_id(&self, name: &str) -> Option<TableId> {
235        // Try exact match first
236        if let Some(&id) = self.tables.get(name) {
237            return Some(id);
238        }
239        // Try case-insensitive match
240        let name_lower = name.to_lowercase();
241        self.tables
242            .iter()
243            .find(|(k, _)| k.to_lowercase() == name_lower)
244            .map(|(_, &id)| id)
245    }
246
247    /// Get table schema by ID
248    pub fn table(&self, id: TableId) -> Option<&TableSchema> {
249        self.table_schemas.get(id.0 as usize)
250    }
251
252    /// Get mutable table schema by ID
253    pub fn table_mut(&mut self, id: TableId) -> Option<&mut TableSchema> {
254        self.table_schemas.get_mut(id.0 as usize)
255    }
256
257    /// Get table schema by name
258    pub fn get_table(&self, name: &str) -> Option<&TableSchema> {
259        self.get_table_id(name).and_then(|id| self.table(id))
260    }
261
262    /// Add a new table schema, returning its ID
263    pub fn add_table(&mut self, mut schema: TableSchema) -> TableId {
264        let id = TableId(self.table_schemas.len() as u32);
265        schema.id = id;
266        self.tables.insert(schema.name.clone(), id);
267        self.table_schemas.push(schema);
268        id
269    }
270
271    /// Resolve all foreign key references to table IDs
272    pub fn resolve_foreign_keys(&mut self) {
273        let table_ids: AHashMap<String, TableId> = self.tables.clone();
274
275        for table in &mut self.table_schemas {
276            for fk in &mut table.foreign_keys {
277                fk.referenced_table_id = table_ids
278                    .get(&fk.referenced_table)
279                    .or_else(|| {
280                        // Case-insensitive fallback
281                        let lower = fk.referenced_table.to_lowercase();
282                        table_ids
283                            .iter()
284                            .find(|(k, _)| k.to_lowercase() == lower)
285                            .map(|(_, v)| v)
286                    })
287                    .copied();
288            }
289        }
290    }
291
292    /// Get the number of tables
293    pub fn len(&self) -> usize {
294        self.table_schemas.len()
295    }
296
297    /// Check if schema is empty
298    pub fn is_empty(&self) -> bool {
299        self.table_schemas.is_empty()
300    }
301
302    /// Iterate over all table schemas
303    pub fn iter(&self) -> impl Iterator<Item = &TableSchema> {
304        self.table_schemas.iter()
305    }
306}
307
308impl Default for Schema {
309    fn default() -> Self {
310        Self::new()
311    }
312}