Skip to main content

sqlx_gen/introspect/
mysql.rs

1use std::collections::HashMap;
2
3use crate::error::Result;
4use sqlx::MySqlPool;
5
6use super::{ColumnInfo, EnumInfo, SchemaInfo, TableInfo};
7
8pub async fn introspect(
9    pool: &MySqlPool,
10    schemas: &[String],
11    include_views: bool,
12) -> Result<SchemaInfo> {
13    let tables = fetch_tables(pool, schemas).await?;
14    let mut views = if include_views {
15        fetch_views(pool, schemas).await?
16    } else {
17        Vec::new()
18    };
19
20    if !views.is_empty() {
21        let sources = fetch_view_column_sources(pool, schemas).await?;
22        resolve_view_nullability(&mut views, &sources, &tables);
23    }
24
25    let enums = extract_enums(&tables);
26
27    Ok(SchemaInfo {
28        tables,
29        views,
30        enums,
31        composite_types: Vec::new(),
32        domains: Vec::new(),
33    })
34}
35
36async fn fetch_tables(pool: &MySqlPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
37    // MySQL doesn't support binding arrays directly, so we build placeholders
38    let placeholders: Vec<String> = (0..schemas.len()).map(|_| "?".to_string()).collect();
39    let query = format!(
40        r#"
41        SELECT
42            c.TABLE_SCHEMA,
43            c.TABLE_NAME,
44            c.COLUMN_NAME,
45            c.DATA_TYPE,
46            c.COLUMN_TYPE,
47            c.IS_NULLABLE,
48            c.ORDINAL_POSITION,
49            c.COLUMN_KEY
50        FROM information_schema.COLUMNS c
51        JOIN information_schema.TABLES t
52            ON t.TABLE_SCHEMA = c.TABLE_SCHEMA
53            AND t.TABLE_NAME = c.TABLE_NAME
54            AND t.TABLE_TYPE = 'BASE TABLE'
55        WHERE c.TABLE_SCHEMA IN ({})
56        ORDER BY c.TABLE_SCHEMA, c.TABLE_NAME, c.ORDINAL_POSITION
57        "#,
58        placeholders.join(",")
59    );
60
61    let mut q = sqlx::query_as::<_, (String, String, String, String, String, String, u32, String)>(&query);
62    for schema in schemas {
63        q = q.bind(schema);
64    }
65    let rows = q.fetch_all(pool).await?;
66
67    let mut tables: Vec<TableInfo> = Vec::new();
68    let mut current_key: Option<(String, String)> = None;
69
70    for (schema, table, col_name, data_type, column_type, nullable, ordinal, column_key) in rows {
71        let key = (schema.clone(), table.clone());
72        if current_key.as_ref() != Some(&key) {
73            current_key = Some(key);
74            tables.push(TableInfo {
75                schema_name: schema.clone(),
76                name: table.clone(),
77                columns: Vec::new(),
78            });
79        }
80        tables.last_mut().unwrap().columns.push(ColumnInfo {
81            name: col_name,
82            data_type,
83            udt_name: column_type,
84            is_nullable: nullable == "YES",
85            is_primary_key: column_key == "PRI",
86            ordinal_position: ordinal as i32,
87            schema_name: schema,
88            column_default: None,
89        });
90    }
91
92    Ok(tables)
93}
94
95async fn fetch_views(pool: &MySqlPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
96    let placeholders: Vec<String> = (0..schemas.len()).map(|_| "?".to_string()).collect();
97    let query = format!(
98        r#"
99        SELECT
100            c.TABLE_SCHEMA,
101            c.TABLE_NAME,
102            c.COLUMN_NAME,
103            c.DATA_TYPE,
104            c.COLUMN_TYPE,
105            c.IS_NULLABLE,
106            c.ORDINAL_POSITION
107        FROM information_schema.COLUMNS c
108        JOIN information_schema.TABLES t
109            ON t.TABLE_SCHEMA = c.TABLE_SCHEMA
110            AND t.TABLE_NAME = c.TABLE_NAME
111            AND t.TABLE_TYPE = 'VIEW'
112        WHERE c.TABLE_SCHEMA IN ({})
113        ORDER BY c.TABLE_SCHEMA, c.TABLE_NAME, c.ORDINAL_POSITION
114        "#,
115        placeholders.join(",")
116    );
117
118    let mut q = sqlx::query_as::<_, (String, String, String, String, String, String, u32)>(&query);
119    for schema in schemas {
120        q = q.bind(schema);
121    }
122    let rows = q.fetch_all(pool).await?;
123
124    let mut views: Vec<TableInfo> = Vec::new();
125    let mut current_key: Option<(String, String)> = None;
126
127    for (schema, table, col_name, data_type, column_type, nullable, ordinal) in rows {
128        let key = (schema.clone(), table.clone());
129        if current_key.as_ref() != Some(&key) {
130            current_key = Some(key);
131            views.push(TableInfo {
132                schema_name: schema.clone(),
133                name: table.clone(),
134                columns: Vec::new(),
135            });
136        }
137        views.last_mut().unwrap().columns.push(ColumnInfo {
138            name: col_name,
139            data_type,
140            udt_name: column_type,
141            is_nullable: nullable == "YES",
142            is_primary_key: false,
143            ordinal_position: ordinal as i32,
144            schema_name: schema,
145            column_default: None,
146        });
147    }
148
149    Ok(views)
150}
151
/// One row read from `INFORMATION_SCHEMA.VIEW_COLUMN_USAGE`.
///
/// The row links a view to a base-table column that the view's definition uses.
/// `column_name` names the referenced *table* column, not a column of the
/// view. The row does not say which view column (if any) that table column
/// feeds.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ViewColumnSource {
    /// Schema (database) that contains the view.
    view_schema: String,
    /// Name of the view.
    view_name: String,
    /// Schema (database) of the referenced table.
    table_schema: String,
    /// Name of the referenced table.
    table_name: String,
    /// Name of the referenced column in `table_schema.table_name`.
    column_name: String,
}
159
160async fn fetch_view_column_sources(
161    pool: &MySqlPool,
162    schemas: &[String],
163) -> Result<Vec<ViewColumnSource>> {
164    let placeholders: Vec<String> = (0..schemas.len()).map(|_| "?".to_string()).collect();
165    let query = format!(
166        r#"
167        SELECT
168            vcu.VIEW_SCHEMA,
169            vcu.VIEW_NAME,
170            vcu.TABLE_SCHEMA,
171            vcu.TABLE_NAME,
172            vcu.COLUMN_NAME
173        FROM INFORMATION_SCHEMA.VIEW_COLUMN_USAGE vcu
174        WHERE vcu.VIEW_SCHEMA IN ({})
175        "#,
176        placeholders.join(",")
177    );
178
179    let mut q = sqlx::query_as::<_, (String, String, String, String, String)>(&query);
180    for schema in schemas {
181        q = q.bind(schema);
182    }
183
184    match q.fetch_all(pool).await {
185        Ok(rows) => Ok(rows
186            .into_iter()
187            .map(
188                |(view_schema, view_name, table_schema, table_name, column_name)| {
189                    ViewColumnSource {
190                        view_schema,
191                        view_name,
192                        table_schema,
193                        table_name,
194                        column_name,
195                    }
196                },
197            )
198            .collect()),
199        Err(_) => {
200            // VIEW_COLUMN_USAGE may not exist on older MySQL versions
201            Ok(Vec::new())
202        }
203    }
204}
205
206fn resolve_view_nullability(
207    views: &mut [TableInfo],
208    sources: &[ViewColumnSource],
209    tables: &[TableInfo],
210) {
211    // Build table column lookup: (schema, table, column) -> is_nullable
212    let mut table_lookup: HashMap<(&str, &str, &str), bool> = HashMap::new();
213    for table in tables {
214        for col in &table.columns {
215            table_lookup.insert(
216                (&table.schema_name, &table.name, &col.name),
217                col.is_nullable,
218            );
219        }
220    }
221
222    // Build view column source lookup: (view_schema, view_name, column_name) -> Vec<is_nullable>
223    let mut view_lookup: HashMap<(&str, &str, &str), Vec<bool>> = HashMap::new();
224    for src in sources {
225        if let Some(&is_nullable) =
226            table_lookup.get(&(src.table_schema.as_str(), src.table_name.as_str(), src.column_name.as_str()))
227        {
228            view_lookup
229                .entry((&src.view_schema, &src.view_name, &src.column_name))
230                .or_default()
231                .push(is_nullable);
232        }
233    }
234
235    for view in views.iter_mut() {
236        for col in view.columns.iter_mut() {
237            if let Some(nullable_flags) = view_lookup.get(&(
238                view.schema_name.as_str(),
239                view.name.as_str(),
240                col.name.as_str(),
241            )) {
242                // Only mark as non-nullable if ALL sources are NOT nullable
243                if !nullable_flags.is_empty() && nullable_flags.iter().all(|&n| !n) {
244                    col.is_nullable = false;
245                }
246            }
247        }
248    }
249}
250
251/// Extract inline ENUMs from column types.
252/// MySQL ENUM('a','b','c') in COLUMN_TYPE gets extracted to an EnumInfo
253/// keyed by table_name + column_name.
254fn extract_enums(tables: &[TableInfo]) -> Vec<EnumInfo> {
255    let mut enums = Vec::new();
256
257    for table in tables {
258        for col in &table.columns {
259            if col.udt_name.starts_with("enum(") {
260                let variants = parse_enum_variants(&col.udt_name);
261                if !variants.is_empty() {
262                    let enum_name = format!("{}_{}", table.name, col.name);
263                    enums.push(EnumInfo {
264                        schema_name: table.schema_name.clone(),
265                        name: enum_name,
266                        variants,
267                        default_variant: None,
268                    });
269                }
270            }
271        }
272    }
273
274    enums
275}
276
/// Parses the member list of a MySQL `ENUM` column type.
///
/// `column_type` is a lowercase `COLUMN_TYPE` string such as
/// `enum('a','b','c')`. Anything else yields an empty list, including
/// upper-case `ENUM(...)` and non-enum types like `varchar(255)`.
///
/// Members are read as single-quoted SQL literals:
/// - commas inside a member are kept, e.g. `enum('a,b','c')` gives `["a,b", "c"]`;
/// - a doubled quote `''` decodes to `'`;
/// - the backslash escapes `\0`, `\n`, `\r`, `\Z` and `\\` are decoded.
///
/// Whitespace between members is ignored, while whitespace inside quotes is
/// kept. Empty members are dropped.
///
/// NOTE(review): the escape handling assumes MySQL renders members the way it
/// quotes them in `SHOW CREATE TABLE` (doubled quotes, backslash escapes) —
/// confirm against the server versions you target.
fn parse_enum_variants(column_type: &str) -> Vec<String> {
    let inner = match column_type
        .strip_prefix("enum(")
        .and_then(|s| s.strip_suffix(')'))
    {
        Some(inner) => inner,
        None => return Vec::new(),
    };

    let mut variants = Vec::new();
    let mut current = String::new();
    let mut in_quotes = false;
    let mut chars = inner.chars().peekable();

    while let Some(c) = chars.next() {
        if in_quotes {
            match c {
                // `''` inside a literal is an escaped single quote.
                '\'' if chars.peek() == Some(&'\'') => {
                    chars.next();
                    current.push('\'');
                }
                // A lone quote closes the literal.
                '\'' => in_quotes = false,
                '\\' => match chars.next() {
                    Some('0') => current.push('\0'),
                    Some('n') => current.push('\n'),
                    Some('r') => current.push('\r'),
                    Some('Z') => current.push('\u{1a}'),
                    Some(other) => current.push(other),
                    // Trailing backslash: keep it verbatim rather than lose it.
                    None => current.push('\\'),
                },
                other => current.push(other),
            }
        } else {
            match c {
                '\'' => in_quotes = true,
                ',' => {
                    if !current.is_empty() {
                        variants.push(std::mem::take(&mut current));
                    }
                }
                // Separator whitespace between members is insignificant.
                ws if ws.is_whitespace() => {}
                // Unquoted text is kept, matching the old lenient behaviour.
                other => current.push(other),
            }
        }
    }
    if !current.is_empty() {
        variants.push(current);
    }

    variants
}
291
#[cfg(test)]
mod tests {
    use super::*;

    // ---------- fixtures ----------

    /// Builds a table named `name` in the `test_db` schema.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: String::from("test_db"),
            name: String::from(name),
            columns,
        }
    }

    /// Builds a NOT NULL, non-key `varchar` column in `test_db` whose
    /// `udt_name` (the MySQL `COLUMN_TYPE`) is `udt_name`.
    fn make_col(name: &str, udt_name: &str) -> ColumnInfo {
        ColumnInfo {
            name: String::from(name),
            data_type: String::from("varchar"),
            udt_name: String::from(udt_name),
            is_nullable: false,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: String::from("test_db"),
            column_default: None,
        }
    }

    /// Builds a table of `varchar(255)` columns with the given nullability,
    /// numbering ordinal positions from 0.
    fn make_table_with_nullability(
        schema: &str,
        name: &str,
        columns: Vec<(&str, bool)>,
    ) -> TableInfo {
        let columns = columns
            .into_iter()
            .zip(0..)
            .map(|((col, nullable), position)| ColumnInfo {
                name: col.to_string(),
                data_type: "varchar".to_string(),
                udt_name: "varchar(255)".to_string(),
                is_nullable: nullable,
                is_primary_key: false,
                ordinal_position: position,
                schema_name: schema.to_string(),
                column_default: None,
            })
            .collect();
        TableInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// Builds a view whose columns all start out nullable.
    fn make_view(schema: &str, name: &str, columns: Vec<&str>) -> TableInfo {
        let nullable_columns = columns.into_iter().map(|col| (col, true)).collect();
        make_table_with_nullability(schema, name, nullable_columns)
    }

    /// Builds one `VIEW_COLUMN_USAGE`-style source row.
    fn make_source(
        view_schema: &str,
        view_name: &str,
        table_schema: &str,
        table_name: &str,
        column_name: &str,
    ) -> ViewColumnSource {
        ViewColumnSource {
            view_schema: String::from(view_schema),
            view_name: String::from(view_name),
            table_schema: String::from(table_schema),
            table_name: String::from(table_name),
            column_name: String::from(column_name),
        }
    }

    // ---------- parse_enum_variants ----------

    #[test]
    fn test_parse_simple() {
        let variants = parse_enum_variants("enum('a','b','c')");
        assert_eq!(variants, ["a", "b", "c"]);
    }

    #[test]
    fn test_parse_single_variant() {
        let variants = parse_enum_variants("enum('only')");
        assert_eq!(variants, ["only"]);
    }

    #[test]
    fn test_parse_with_spaces() {
        let variants = parse_enum_variants("enum( 'a' , 'b' )");
        assert_eq!(variants, ["a", "b"]);
    }

    #[test]
    fn test_parse_empty_parens() {
        assert_eq!(parse_enum_variants("enum()"), Vec::<String>::new());
    }

    #[test]
    fn test_parse_varchar_not_enum() {
        assert_eq!(parse_enum_variants("varchar(255)"), Vec::<String>::new());
    }

    #[test]
    fn test_parse_int_not_enum() {
        assert_eq!(parse_enum_variants("int"), Vec::<String>::new());
    }

    #[test]
    fn test_parse_with_spaces_in_value() {
        let variants = parse_enum_variants("enum('with space','no')");
        assert_eq!(variants, ["with space", "no"]);
    }

    #[test]
    fn test_parse_empty_variant_filtered() {
        let variants = parse_enum_variants("enum('a','','c')");
        assert_eq!(variants, ["a", "c"]);
    }

    #[test]
    fn test_parse_uppercase_enum_not_matched() {
        // Only the lowercase "enum(" prefix is recognised.
        assert_eq!(parse_enum_variants("ENUM('a','b')"), Vec::<String>::new());
    }

    // ---------- extract_enums ----------

    #[test]
    fn test_extract_from_enum_column() {
        let tables = [make_table(
            "users",
            vec![make_col("status", "enum('active','inactive')")],
        )];
        let enums = extract_enums(&tables);
        assert_eq!(enums.len(), 1);
        assert_eq!(enums[0].variants, ["active", "inactive"]);
    }

    #[test]
    fn test_extract_enum_name_format() {
        let tables = [make_table("users", vec![make_col("status", "enum('a')")])];
        let enums = extract_enums(&tables);
        assert_eq!(enums[0].name, "users_status");
    }

    #[test]
    fn test_extract_no_enums() {
        let tables = [make_table(
            "users",
            vec![make_col("id", "int"), make_col("name", "varchar(255)")],
        )];
        assert!(extract_enums(&tables).is_empty());
    }

    #[test]
    fn test_extract_two_enum_columns_same_table() {
        let tables = [make_table(
            "users",
            vec![
                make_col("status", "enum('active','inactive')"),
                make_col("role", "enum('admin','user')"),
            ],
        )];
        let enums = extract_enums(&tables);
        let names: Vec<&str> = enums.iter().map(|e| e.name.as_str()).collect();
        assert_eq!(names, ["users_status", "users_role"]);
    }

    #[test]
    fn test_extract_enums_from_multiple_tables() {
        let tables = [
            make_table("users", vec![make_col("status", "enum('a')")]),
            make_table("posts", vec![make_col("state", "enum('b')")]),
        ];
        assert_eq!(extract_enums(&tables).len(), 2);
    }

    #[test]
    fn test_extract_non_enum_column_ignored() {
        let tables = [make_table(
            "users",
            vec![make_col("id", "int(11)"), make_col("status", "enum('a')")],
        )];
        assert_eq!(extract_enums(&tables).len(), 1);
    }

    // ---------- resolve_view_nullability ----------

    #[test]
    fn test_resolve_not_null_column() {
        let tables = [make_table_with_nullability(
            "db",
            "users",
            vec![("id", false), ("name", false)],
        )];
        let mut views = [make_view("db", "my_view", vec!["id", "name"])];
        let sources = [
            make_source("db", "my_view", "db", "users", "id"),
            make_source("db", "my_view", "db", "users", "name"),
        ];
        resolve_view_nullability(&mut views, &sources, &tables);
        let nullability: Vec<bool> = views[0].columns.iter().map(|c| c.is_nullable).collect();
        assert_eq!(nullability, [false, false]);
    }

    #[test]
    fn test_resolve_nullable_source() {
        let tables = [make_table_with_nullability(
            "db",
            "users",
            vec![("id", false), ("name", true)],
        )];
        let mut views = [make_view("db", "my_view", vec!["id", "name"])];
        let sources = [
            make_source("db", "my_view", "db", "users", "id"),
            make_source("db", "my_view", "db", "users", "name"),
        ];
        resolve_view_nullability(&mut views, &sources, &tables);
        let nullability: Vec<bool> = views[0].columns.iter().map(|c| c.is_nullable).collect();
        assert_eq!(nullability, [false, true]);
    }

    #[test]
    fn test_resolve_no_match_stays_nullable() {
        let tables = [make_table_with_nullability("db", "users", vec![("id", false)])];
        let mut views = [make_view("db", "my_view", vec!["computed"])];
        let sources: Vec<ViewColumnSource> = Vec::new();
        resolve_view_nullability(&mut views, &sources, &tables);
        assert!(views[0].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_empty_sources() {
        let tables: Vec<TableInfo> = Vec::new();
        let mut views = [make_view("db", "my_view", vec!["id"])];
        resolve_view_nullability(&mut views, &[], &tables);
        assert!(views[0].columns[0].is_nullable);
    }
}