Skip to main content

sqlx_gen/introspect/
mysql.rs

1use crate::error::Result;
2use sqlx::MySqlPool;
3
4use super::{ColumnInfo, EnumInfo, SchemaInfo, TableInfo};
5
6pub async fn introspect(
7    pool: &MySqlPool,
8    schemas: &[String],
9    include_views: bool,
10) -> Result<SchemaInfo> {
11    let tables = fetch_tables(pool, schemas).await?;
12    let views = if include_views {
13        fetch_views(pool, schemas).await?
14    } else {
15        Vec::new()
16    };
17    let enums = extract_enums(&tables);
18
19    Ok(SchemaInfo {
20        tables,
21        views,
22        enums,
23        composite_types: Vec::new(),
24        domains: Vec::new(),
25    })
26}
27
28async fn fetch_tables(pool: &MySqlPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
29    // MySQL doesn't support binding arrays directly, so we build placeholders
30    let placeholders: Vec<String> = (0..schemas.len()).map(|_| "?".to_string()).collect();
31    let query = format!(
32        r#"
33        SELECT
34            c.TABLE_SCHEMA,
35            c.TABLE_NAME,
36            c.COLUMN_NAME,
37            c.DATA_TYPE,
38            c.COLUMN_TYPE,
39            c.IS_NULLABLE,
40            c.ORDINAL_POSITION,
41            c.COLUMN_KEY
42        FROM information_schema.COLUMNS c
43        JOIN information_schema.TABLES t
44            ON t.TABLE_SCHEMA = c.TABLE_SCHEMA
45            AND t.TABLE_NAME = c.TABLE_NAME
46            AND t.TABLE_TYPE = 'BASE TABLE'
47        WHERE c.TABLE_SCHEMA IN ({})
48        ORDER BY c.TABLE_SCHEMA, c.TABLE_NAME, c.ORDINAL_POSITION
49        "#,
50        placeholders.join(",")
51    );
52
53    let mut q = sqlx::query_as::<_, (String, String, String, String, String, String, u32, String)>(&query);
54    for schema in schemas {
55        q = q.bind(schema);
56    }
57    let rows = q.fetch_all(pool).await?;
58
59    let mut tables: Vec<TableInfo> = Vec::new();
60    let mut current_key: Option<(String, String)> = None;
61
62    for (schema, table, col_name, data_type, column_type, nullable, ordinal, column_key) in rows {
63        let key = (schema.clone(), table.clone());
64        if current_key.as_ref() != Some(&key) {
65            current_key = Some(key);
66            tables.push(TableInfo {
67                schema_name: schema.clone(),
68                name: table.clone(),
69                columns: Vec::new(),
70            });
71        }
72        tables.last_mut().unwrap().columns.push(ColumnInfo {
73            name: col_name,
74            data_type,
75            udt_name: column_type,
76            is_nullable: nullable == "YES",
77            is_primary_key: column_key == "PRI",
78            ordinal_position: ordinal as i32,
79            schema_name: schema,
80            column_default: None,
81        });
82    }
83
84    Ok(tables)
85}
86
87async fn fetch_views(pool: &MySqlPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
88    let placeholders: Vec<String> = (0..schemas.len()).map(|_| "?".to_string()).collect();
89    let query = format!(
90        r#"
91        SELECT
92            c.TABLE_SCHEMA,
93            c.TABLE_NAME,
94            c.COLUMN_NAME,
95            c.DATA_TYPE,
96            c.COLUMN_TYPE,
97            c.IS_NULLABLE,
98            c.ORDINAL_POSITION
99        FROM information_schema.COLUMNS c
100        JOIN information_schema.TABLES t
101            ON t.TABLE_SCHEMA = c.TABLE_SCHEMA
102            AND t.TABLE_NAME = c.TABLE_NAME
103            AND t.TABLE_TYPE = 'VIEW'
104        WHERE c.TABLE_SCHEMA IN ({})
105        ORDER BY c.TABLE_SCHEMA, c.TABLE_NAME, c.ORDINAL_POSITION
106        "#,
107        placeholders.join(",")
108    );
109
110    let mut q = sqlx::query_as::<_, (String, String, String, String, String, String, u32)>(&query);
111    for schema in schemas {
112        q = q.bind(schema);
113    }
114    let rows = q.fetch_all(pool).await?;
115
116    let mut views: Vec<TableInfo> = Vec::new();
117    let mut current_key: Option<(String, String)> = None;
118
119    for (schema, table, col_name, data_type, column_type, nullable, ordinal) in rows {
120        let key = (schema.clone(), table.clone());
121        if current_key.as_ref() != Some(&key) {
122            current_key = Some(key);
123            views.push(TableInfo {
124                schema_name: schema.clone(),
125                name: table.clone(),
126                columns: Vec::new(),
127            });
128        }
129        views.last_mut().unwrap().columns.push(ColumnInfo {
130            name: col_name,
131            data_type,
132            udt_name: column_type,
133            is_nullable: nullable == "YES",
134            is_primary_key: false,
135            ordinal_position: ordinal as i32,
136            schema_name: schema,
137            column_default: None,
138        });
139    }
140
141    Ok(views)
142}
143
144/// Extract inline ENUMs from column types.
145/// MySQL ENUM('a','b','c') in COLUMN_TYPE gets extracted to an EnumInfo
146/// keyed by table_name + column_name.
147fn extract_enums(tables: &[TableInfo]) -> Vec<EnumInfo> {
148    let mut enums = Vec::new();
149
150    for table in tables {
151        for col in &table.columns {
152            if col.udt_name.starts_with("enum(") {
153                let variants = parse_enum_variants(&col.udt_name);
154                if !variants.is_empty() {
155                    let enum_name = format!("{}_{}", table.name, col.name);
156                    enums.push(EnumInfo {
157                        schema_name: table.schema_name.clone(),
158                        name: enum_name,
159                        variants,
160                        default_variant: None,
161                    });
162                }
163            }
164        }
165    }
166
167    enums
168}
169
/// Parse the member list out of a MySQL `COLUMN_TYPE` string,
/// e.g. `enum('a','b','c')` → `["a", "b", "c"]`.
///
/// The list is read as a sequence of single-quoted SQL literals rather than
/// split on commas, so members that contain `,` or `)` stay intact. Inside a
/// literal, a doubled quote (`''`) decodes to one `'`. Backslash escapes
/// (`\\`, `\0`, `\n`, `\r`, `\t`, `\Z`) decode to the character they denote.
/// NOTE(review): these are the escape forms MySQL uses when it renders enum
/// members into `COLUMN_TYPE`; verify against the server version in use.
///
/// Unquoted members (never emitted by MySQL) are still accepted and trimmed.
/// Empty members are dropped. Returns an empty vector when `column_type` is
/// not of the lowercase form `enum(...)`.
fn parse_enum_variants(column_type: &str) -> Vec<String> {
    type Cursor<'a> = std::iter::Peekable<std::str::Chars<'a>>;

    /// Read the body of a quoted literal whose opening `'` was already
    /// consumed, stopping after the closing `'` (or at end of input).
    fn read_quoted(chars: &mut Cursor<'_>) -> String {
        let mut value = String::new();
        while let Some(c) = chars.next() {
            match c {
                // `''` is an escaped quote; a lone `'` closes the literal.
                '\'' => {
                    if chars.next_if_eq(&'\'').is_some() {
                        value.push('\'');
                    } else {
                        break;
                    }
                }
                '\\' => match chars.next() {
                    Some('0') => value.push('\0'),
                    Some('n') => value.push('\n'),
                    Some('r') => value.push('\r'),
                    Some('t') => value.push('\t'),
                    Some('Z') => value.push('\u{1a}'),
                    Some(other) => value.push(other),
                    None => value.push('\\'),
                },
                other => value.push(other),
            }
        }
        value
    }

    let inner = match column_type
        .strip_prefix("enum(")
        .and_then(|s| s.strip_suffix(')'))
    {
        Some(s) => s,
        None => return Vec::new(),
    };

    let mut variants = Vec::new();
    let mut chars: Cursor<'_> = inner.chars().peekable();
    loop {
        // Skip separators and any whitespace between members.
        while chars.next_if(|c| *c == ',' || c.is_whitespace()).is_some() {}

        let variant = match chars.next() {
            None => break,
            Some('\'') => read_quoted(&mut chars),
            Some(first) => {
                // Lenient fallback for an unquoted member: read up to the next comma.
                let mut raw = String::from(first);
                while let Some(c) = chars.next_if(|c| *c != ',') {
                    raw.push(c);
                }
                raw.trim().to_string()
            }
        };

        if !variant.is_empty() {
            variants.push(variant);
        }
    }
    variants
}
184
#[cfg(test)]
mod tests {
    use super::*;

    /// A `TableInfo` fixture living in the `test_db` schema.
    fn table_fixture(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            columns,
            name: name.to_string(),
            schema_name: "test_db".to_string(),
        }
    }

    /// A non-null, non-key `varchar` column fixture whose `COLUMN_TYPE` is `udt_name`.
    fn column_fixture(name: &str, udt_name: &str) -> ColumnInfo {
        ColumnInfo {
            schema_name: "test_db".to_string(),
            name: name.to_string(),
            data_type: "varchar".to_string(),
            udt_name: udt_name.to_string(),
            ordinal_position: 0,
            is_nullable: false,
            is_primary_key: false,
            column_default: None,
        }
    }

    // ---------- parse_enum_variants ----------

    #[test]
    fn test_parse_simple() {
        let parsed = parse_enum_variants("enum('a','b','c')");
        assert_eq!(parsed, ["a", "b", "c"]);
    }

    #[test]
    fn test_parse_single_variant() {
        let parsed = parse_enum_variants("enum('only')");
        assert_eq!(parsed, ["only"]);
    }

    #[test]
    fn test_parse_with_spaces() {
        let parsed = parse_enum_variants("enum( 'a' , 'b' )");
        assert_eq!(parsed, ["a", "b"]);
    }

    #[test]
    fn test_parse_empty_parens() {
        assert_eq!(parse_enum_variants("enum()"), Vec::<String>::new());
    }

    #[test]
    fn test_parse_varchar_not_enum() {
        assert!(parse_enum_variants("varchar(255)").is_empty());
    }

    #[test]
    fn test_parse_int_not_enum() {
        assert!(parse_enum_variants("int").is_empty());
    }

    #[test]
    fn test_parse_with_spaces_in_value() {
        let parsed = parse_enum_variants("enum('with space','no')");
        assert_eq!(parsed, ["with space", "no"]);
    }

    #[test]
    fn test_parse_empty_variant_filtered() {
        let parsed = parse_enum_variants("enum('a','','c')");
        assert_eq!(parsed, ["a", "c"]);
    }

    #[test]
    fn test_parse_uppercase_enum_not_matched() {
        // The prefix match is case-sensitive: "ENUM(" is not "enum(".
        assert!(parse_enum_variants("ENUM('a','b')").is_empty());
    }

    // ---------- extract_enums ----------

    #[test]
    fn test_extract_from_enum_column() {
        let tables = [table_fixture(
            "users",
            vec![column_fixture("status", "enum('active','inactive')")],
        )];
        let enums = extract_enums(&tables);
        assert_eq!(enums.len(), 1);
        assert_eq!(enums[0].variants, ["active", "inactive"]);
    }

    #[test]
    fn test_extract_enum_name_format() {
        let tables = [table_fixture(
            "users",
            vec![column_fixture("status", "enum('a')")],
        )];
        assert_eq!(extract_enums(&tables)[0].name, "users_status");
    }

    #[test]
    fn test_extract_no_enums() {
        let tables = [table_fixture(
            "users",
            vec![
                column_fixture("id", "int"),
                column_fixture("name", "varchar(255)"),
            ],
        )];
        assert_eq!(extract_enums(&tables).len(), 0);
    }

    #[test]
    fn test_extract_two_enum_columns_same_table() {
        let tables = [table_fixture(
            "users",
            vec![
                column_fixture("status", "enum('active','inactive')"),
                column_fixture("role", "enum('admin','user')"),
            ],
        )];
        let enums = extract_enums(&tables);
        let names: Vec<&str> = enums.iter().map(|e| e.name.as_str()).collect();
        assert_eq!(names, ["users_status", "users_role"]);
    }

    #[test]
    fn test_extract_enums_from_multiple_tables() {
        let tables = [
            table_fixture("users", vec![column_fixture("status", "enum('a')")]),
            table_fixture("posts", vec![column_fixture("state", "enum('b')")]),
        ];
        assert_eq!(extract_enums(&tables).len(), 2);
    }

    #[test]
    fn test_extract_non_enum_column_ignored() {
        let tables = [table_fixture(
            "users",
            vec![
                column_fixture("id", "int(11)"),
                column_fixture("status", "enum('a')"),
            ],
        )];
        assert_eq!(extract_enums(&tables).len(), 1);
    }
}