sql_splitter/schema/
mod.rs1mod ddl;
10mod graph;
11
12pub use ddl::*;
13pub use graph::*;
14
15use ahash::AHashMap;
16use std::fmt;
17
/// Dense numeric handle for a table; used as an index into `Schema::table_schemas`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TableId(pub u32);

impl fmt::Display for TableId {
    /// Renders as `TableId(<n>)`; formatter flags such as width are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let TableId(raw) = *self;
        f.write_fmt(format_args!("TableId({})", raw))
    }
}
27
/// Position of a column within its table; `TableSchema::column` uses it as an
/// index into `TableSchema::columns`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ColumnId(pub u16);

impl fmt::Display for ColumnId {
    /// Renders as `ColumnId(<n>)`; formatter flags such as width are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ColumnId(ordinal) = self;
        write!(f, "ColumnId({})", ordinal)
    }
}
37
/// Coarse classification of a SQL column type, produced by
/// [`ColumnType::from_sql_type`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ColumnType {
    /// Integer types narrower than `bigint` (`int`, `smallint`, `tinyint`,
    /// `mediumint`, `int2`/`int4`, `serial`, `smallserial`).
    Int,
    /// `bigint`, `int8`, `bigserial`.
    BigInt,
    /// Character and string-like types (`char`, `varchar`, `text` variants,
    /// `enum`, `set`, `character`).
    Text,
    /// `uuid`, or a binary column whose declared length is exactly 16 bytes.
    Uuid,
    /// Fixed- and floating-point numerics (`decimal`, `numeric`, `float`,
    /// `double`, `real`, `money`, ...).
    Decimal,
    /// Date/time types, including `interval` and `year`.
    DateTime,
    /// `bool` / `boolean`.
    Bool,
    /// Anything not recognised; holds the original (un-lowercased) type string.
    Other(String),
}

impl ColumnType {
    /// Classifies a SQL type string such as `"VARCHAR(255)"` or `"int unsigned"`.
    ///
    /// Matching is case-insensitive. Parenthesised arguments are ignored, and
    /// only the first word of the remaining base type is considered. That lets
    /// modifiers and multi-word spellings (`bigint unsigned`,
    /// `character varying`, `double precision`, `timestamp with time zone`)
    /// classify by their leading keyword. Binary types are reported as
    /// [`ColumnType::Uuid`] only when their length argument is exactly 16.
    pub fn from_sql_type(type_str: &str) -> Self {
        let type_lower = type_str.to_lowercase();
        let base_type = type_lower
            .split('(')
            .next()
            .unwrap_or("")
            .split_whitespace()
            .next()
            .unwrap_or("");

        match base_type {
            "int" | "integer" | "tinyint" | "smallint" | "mediumint" | "int4" | "int2"
            | "serial" | "smallserial" => ColumnType::Int,
            "bigint" | "int8" | "bigserial" => ColumnType::BigInt,
            "char" | "varchar" | "text" | "tinytext" | "mediumtext" | "longtext" | "enum"
            | "set" | "character" => ColumnType::Text,
            "decimal" | "numeric" | "float" | "double" | "real" | "float4" | "float8"
            | "money" => ColumnType::Decimal,
            "date" | "datetime" | "timestamp" | "time" | "year" | "timestamptz" | "timetz"
            | "interval" => ColumnType::DateTime,
            "bool" | "boolean" => ColumnType::Bool,
            "binary" | "varbinary" | "blob" | "bytea" => {
                // A 16-byte binary column is the usual compact UUID encoding.
                // Compare the parsed length exactly; a substring check would
                // also match lengths like 160 or 216.
                if Self::length_arg(&type_lower) == Some(16) {
                    ColumnType::Uuid
                } else {
                    ColumnType::Other(type_str.to_string())
                }
            }
            "uuid" => ColumnType::Uuid,
            _ => ColumnType::Other(type_str.to_string()),
        }
    }

    /// Alias for [`ColumnType::from_sql_type`]; the same mapping is applied.
    pub fn from_mysql_type(type_str: &str) -> Self {
        Self::from_sql_type(type_str)
    }

    /// Parses the first parenthesised argument of a type, e.g. `16` in
    /// `binary(16)`. Returns `None` when it is absent or not an integer.
    fn length_arg(type_lower: &str) -> Option<u32> {
        let (_, rest) = type_lower.split_once('(')?;
        let (inner, _) = rest.split_once(')')?;
        inner.trim().parse().ok()
    }
}
105
/// One column of a table schema.
#[derive(Debug, Clone)]
pub struct Column {
    /// Column name. `TableSchema::get_column` matches it ASCII case-insensitively.
    pub name: String,
    /// Classified type of the column (see `ColumnType::from_sql_type`).
    pub col_type: ColumnType,
    /// Position of this column in its table. `TableSchema::column` indexes
    /// `TableSchema::columns` with this value, so it is expected to equal the
    /// column's index in that Vec.
    pub ordinal: ColumnId,
    /// Whether this column belongs to the table's primary key.
    pub is_primary_key: bool,
    /// Whether the column accepts NULL.
    pub is_nullable: bool,
}
120
/// A foreign-key constraint declared on a table.
#[derive(Debug, Clone)]
pub struct ForeignKey {
    /// Constraint name, if one was given (presumably the `CONSTRAINT <name>`
    /// identifier — confirm against the DDL parser).
    pub name: Option<String>,
    /// Ids of the referencing (local) columns.
    pub columns: Vec<ColumnId>,
    /// Names of the referencing columns; presumably parallel to `columns`.
    pub column_names: Vec<String>,
    /// Referenced table name as written. `Schema::resolve_foreign_keys` resolves
    /// it into `referenced_table_id`.
    pub referenced_table: String,
    /// Column names in the referenced table.
    pub referenced_columns: Vec<String>,
    /// Resolved id of `referenced_table`. Set by `Schema::resolve_foreign_keys`,
    /// which stores `None` when no table with that name exists.
    pub referenced_table_id: Option<TableId>,
}
137
/// Schema of a single table: its columns, primary key, and foreign keys.
#[derive(Debug, Clone)]
pub struct TableSchema {
    /// Table name.
    pub name: String,
    /// Id of this table; overwritten by `Schema::add_table` on registration.
    pub id: TableId,
    /// Columns, indexed by `ColumnId` (see `TableSchema::column`).
    pub columns: Vec<Column>,
    /// Ids of the primary-key columns.
    pub primary_key: Vec<ColumnId>,
    /// Foreign keys declared on this table.
    pub foreign_keys: Vec<ForeignKey>,
    /// Text of the table's CREATE statement, if it was kept; `TableSchema::new`
    /// starts it as `None`.
    pub create_statement: Option<String>,
}
154
155impl TableSchema {
156 pub fn new(name: String, id: TableId) -> Self {
158 Self {
159 name,
160 id,
161 columns: Vec::new(),
162 primary_key: Vec::new(),
163 foreign_keys: Vec::new(),
164 create_statement: None,
165 }
166 }
167
168 pub fn get_column(&self, name: &str) -> Option<&Column> {
170 self.columns
171 .iter()
172 .find(|c| c.name.eq_ignore_ascii_case(name))
173 }
174
175 pub fn get_column_id(&self, name: &str) -> Option<ColumnId> {
177 self.get_column(name).map(|c| c.ordinal)
178 }
179
180 pub fn column(&self, id: ColumnId) -> Option<&Column> {
182 self.columns.get(id.0 as usize)
183 }
184
185 pub fn is_pk_column(&self, col_id: ColumnId) -> bool {
187 self.primary_key.contains(&col_id)
188 }
189
190 pub fn fk_column_ids(&self) -> Vec<ColumnId> {
192 self.foreign_keys
193 .iter()
194 .flat_map(|fk| fk.columns.iter().copied())
195 .collect()
196 }
197}
198
/// Collection of table schemas, addressable by name or by [`TableId`].
#[derive(Debug)]
pub struct Schema {
    /// Table name (exact case, as passed to `Schema::add_table`) to id.
    pub tables: AHashMap<String, TableId>,
    /// Table schemas in insertion order; `TableId(n)` refers to index `n`.
    pub table_schemas: Vec<TableSchema>,
}
207
208impl Schema {
209 pub fn new() -> Self {
211 Self {
212 tables: AHashMap::new(),
213 table_schemas: Vec::new(),
214 }
215 }
216
217 pub fn get_table_id(&self, name: &str) -> Option<TableId> {
219 if let Some(&id) = self.tables.get(name) {
221 return Some(id);
222 }
223 let name_lower = name.to_lowercase();
225 self.tables
226 .iter()
227 .find(|(k, _)| k.to_lowercase() == name_lower)
228 .map(|(_, &id)| id)
229 }
230
231 pub fn table(&self, id: TableId) -> Option<&TableSchema> {
233 self.table_schemas.get(id.0 as usize)
234 }
235
236 pub fn table_mut(&mut self, id: TableId) -> Option<&mut TableSchema> {
238 self.table_schemas.get_mut(id.0 as usize)
239 }
240
241 pub fn get_table(&self, name: &str) -> Option<&TableSchema> {
243 self.get_table_id(name).and_then(|id| self.table(id))
244 }
245
246 pub fn add_table(&mut self, mut schema: TableSchema) -> TableId {
248 let id = TableId(self.table_schemas.len() as u32);
249 schema.id = id;
250 self.tables.insert(schema.name.clone(), id);
251 self.table_schemas.push(schema);
252 id
253 }
254
255 pub fn resolve_foreign_keys(&mut self) {
257 let table_ids: AHashMap<String, TableId> = self.tables.clone();
258
259 for table in &mut self.table_schemas {
260 for fk in &mut table.foreign_keys {
261 fk.referenced_table_id = table_ids
262 .get(&fk.referenced_table)
263 .or_else(|| {
264 let lower = fk.referenced_table.to_lowercase();
266 table_ids
267 .iter()
268 .find(|(k, _)| k.to_lowercase() == lower)
269 .map(|(_, v)| v)
270 })
271 .copied();
272 }
273 }
274 }
275
276 pub fn len(&self) -> usize {
278 self.table_schemas.len()
279 }
280
281 pub fn is_empty(&self) -> bool {
283 self.table_schemas.is_empty()
284 }
285
286 pub fn iter(&self) -> impl Iterator<Item = &TableSchema> {
288 self.table_schemas.iter()
289 }
290}
291
292impl Default for Schema {
293 fn default() -> Self {
294 Self::new()
295 }
296}