Skip to main content

sqlx_gen/introspect/
mysql.rs

1use std::collections::HashMap;
2
3use crate::error::Result;
4use sqlx::MySqlPool;
5
6use super::{ColumnInfo, EnumInfo, SchemaInfo, TableInfo};
7
8pub async fn introspect(
9    pool: &MySqlPool,
10    schemas: &[String],
11    include_views: bool,
12) -> Result<SchemaInfo> {
13    let tables = fetch_tables(pool, schemas).await?;
14    let mut views = if include_views {
15        fetch_views(pool, schemas).await?
16    } else {
17        Vec::new()
18    };
19
20    if !views.is_empty() {
21        let sources = fetch_view_column_sources(pool, schemas).await?;
22        resolve_view_nullability(&mut views, &sources, &tables);
23        resolve_view_primary_keys(&mut views, &sources, &tables);
24    }
25
26    let enums = extract_enums(&tables);
27
28    Ok(SchemaInfo {
29        tables,
30        views,
31        enums,
32        composite_types: Vec::new(),
33        domains: Vec::new(),
34    })
35}
36
37async fn fetch_tables(pool: &MySqlPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
38    // MySQL doesn't support binding arrays directly, so we build placeholders
39    let placeholders: Vec<String> = (0..schemas.len()).map(|_| "?".to_string()).collect();
40    let query = format!(
41        r#"
42        SELECT
43            c.TABLE_SCHEMA,
44            c.TABLE_NAME,
45            c.COLUMN_NAME,
46            c.DATA_TYPE,
47            c.COLUMN_TYPE,
48            c.IS_NULLABLE,
49            c.ORDINAL_POSITION,
50            c.COLUMN_KEY
51        FROM information_schema.COLUMNS c
52        JOIN information_schema.TABLES t
53            ON t.TABLE_SCHEMA = c.TABLE_SCHEMA
54            AND t.TABLE_NAME = c.TABLE_NAME
55            AND t.TABLE_TYPE = 'BASE TABLE'
56        WHERE c.TABLE_SCHEMA IN ({})
57        ORDER BY c.TABLE_SCHEMA, c.TABLE_NAME, c.ORDINAL_POSITION
58        "#,
59        placeholders.join(",")
60    );
61
62    let mut q = sqlx::query_as::<
63        _,
64        (
65            Vec<u8>,
66            Vec<u8>,
67            Vec<u8>,
68            Vec<u8>,
69            Vec<u8>,
70            Vec<u8>,
71            u32,
72            Vec<u8>,
73        ),
74    >(&query);
75    for schema in schemas {
76        q = q.bind(schema);
77    }
78    let rows = q.fetch_all(pool).await?;
79
80    let mut tables: Vec<TableInfo> = Vec::new();
81    let mut current_key: Option<(String, String)> = None;
82
83    for (schema, table, col_name, data_type, column_type, nullable, ordinal, column_key) in rows {
84        let schema = utf8_field(schema, "TABLE_SCHEMA")?;
85        let table = utf8_field(table, "TABLE_NAME")?;
86        let col_name = utf8_field(col_name, "COLUMN_NAME")?;
87        let data_type = utf8_field(data_type, "DATA_TYPE")?;
88        let column_type = utf8_field(column_type, "COLUMN_TYPE")?;
89        let nullable = utf8_field(nullable, "IS_NULLABLE")?;
90        let column_key = utf8_field(column_key, "COLUMN_KEY")?;
91
92        let key = (schema.clone(), table.clone());
93        if current_key.as_ref() != Some(&key) {
94            current_key = Some(key);
95            tables.push(TableInfo {
96                schema_name: schema.clone(),
97                name: table.clone(),
98                columns: Vec::new(),
99            });
100        }
101        let last = tables.last_mut().ok_or_else(|| {
102            crate::error::Error::Config(
103                "Internal sqlx-gen bug: tables vector empty after push".to_string(),
104            )
105        })?;
106        last.columns.push(ColumnInfo {
107            name: col_name,
108            data_type,
109            udt_name: column_type,
110            udt_schema: None,
111            is_nullable: nullable == "YES",
112            is_primary_key: column_key == "PRI",
113            ordinal_position: ordinal as i32,
114            schema_name: schema,
115            column_default: None,
116        });
117    }
118
119    Ok(tables)
120}
121
122/// Decode a MySQL `Vec<u8>` metadata field as UTF-8, returning a structured
123/// error instead of panicking if the bytes are invalid.
124fn utf8_field(bytes: Vec<u8>, field: &str) -> Result<String> {
125    String::from_utf8(bytes).map_err(|_| {
126        crate::error::Error::Config(format!(
127            "Database returned non-UTF8 bytes for MySQL information_schema field '{}'. \
128             sqlx-gen requires UTF-8 metadata.",
129            field
130        ))
131    })
132}
133
134async fn fetch_views(pool: &MySqlPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
135    let placeholders: Vec<String> = (0..schemas.len()).map(|_| "?".to_string()).collect();
136    let query = format!(
137        r#"
138        SELECT
139            c.TABLE_SCHEMA,
140            c.TABLE_NAME,
141            c.COLUMN_NAME,
142            c.DATA_TYPE,
143            c.COLUMN_TYPE,
144            c.IS_NULLABLE,
145            c.ORDINAL_POSITION
146        FROM information_schema.COLUMNS c
147        JOIN information_schema.TABLES t
148            ON t.TABLE_SCHEMA = c.TABLE_SCHEMA
149            AND t.TABLE_NAME = c.TABLE_NAME
150            AND t.TABLE_TYPE = 'VIEW'
151        WHERE c.TABLE_SCHEMA IN ({})
152        ORDER BY c.TABLE_SCHEMA, c.TABLE_NAME, c.ORDINAL_POSITION
153        "#,
154        placeholders.join(",")
155    );
156
157    let mut q = sqlx::query_as::<_, (String, String, String, String, String, String, u32)>(&query);
158    for schema in schemas {
159        q = q.bind(schema);
160    }
161    let rows = q.fetch_all(pool).await?;
162
163    let mut views: Vec<TableInfo> = Vec::new();
164    let mut current_key: Option<(String, String)> = None;
165
166    for (schema, table, col_name, data_type, column_type, nullable, ordinal) in rows {
167        let key = (schema.clone(), table.clone());
168        if current_key.as_ref() != Some(&key) {
169            current_key = Some(key);
170            views.push(TableInfo {
171                schema_name: schema.clone(),
172                name: table.clone(),
173                columns: Vec::new(),
174            });
175        }
176        let last = views.last_mut().ok_or_else(|| {
177            crate::error::Error::Config(
178                "Internal sqlx-gen bug: views vector empty after push".to_string(),
179            )
180        })?;
181        last.columns.push(ColumnInfo {
182            name: col_name,
183            data_type,
184            udt_name: column_type,
185            udt_schema: None,
186            is_nullable: nullable == "YES",
187            is_primary_key: false,
188            ordinal_position: ordinal as i32,
189            schema_name: schema,
190            column_default: None,
191        });
192    }
193
194    Ok(views)
195}
196
/// One row of `INFORMATION_SCHEMA.VIEW_COLUMN_USAGE`: a base-table column
/// that a view is reported to use.
struct ViewColumnSource {
    /// `VIEW_SCHEMA`: schema that contains the view.
    view_schema: String,
    /// `VIEW_NAME`: name of the view.
    view_name: String,
    /// `TABLE_SCHEMA`: schema of the referenced table.
    table_schema: String,
    /// `TABLE_NAME`: name of the referenced table.
    table_name: String,
    /// `COLUMN_NAME`: name of the referenced column.
    column_name: String,
}
204
205async fn fetch_view_column_sources(
206    pool: &MySqlPool,
207    schemas: &[String],
208) -> Result<Vec<ViewColumnSource>> {
209    let placeholders: Vec<String> = (0..schemas.len()).map(|_| "?".to_string()).collect();
210    let query = format!(
211        r#"
212        SELECT
213            vcu.VIEW_SCHEMA,
214            vcu.VIEW_NAME,
215            vcu.TABLE_SCHEMA,
216            vcu.TABLE_NAME,
217            vcu.COLUMN_NAME
218        FROM INFORMATION_SCHEMA.VIEW_COLUMN_USAGE vcu
219        WHERE vcu.VIEW_SCHEMA IN ({})
220        "#,
221        placeholders.join(",")
222    );
223
224    let mut q = sqlx::query_as::<_, (String, String, String, String, String)>(&query);
225    for schema in schemas {
226        q = q.bind(schema);
227    }
228
229    match q.fetch_all(pool).await {
230        Ok(rows) => Ok(rows
231            .into_iter()
232            .map(
233                |(view_schema, view_name, table_schema, table_name, column_name)| {
234                    ViewColumnSource {
235                        view_schema,
236                        view_name,
237                        table_schema,
238                        table_name,
239                        column_name,
240                    }
241                },
242            )
243            .collect()),
244        Err(_) => {
245            // VIEW_COLUMN_USAGE may not exist on older MySQL versions
246            Ok(Vec::new())
247        }
248    }
249}
250
251fn resolve_view_nullability(
252    views: &mut [TableInfo],
253    sources: &[ViewColumnSource],
254    tables: &[TableInfo],
255) {
256    // Build table column lookup: (schema, table, column) -> is_nullable
257    let mut table_lookup: HashMap<(&str, &str, &str), bool> = HashMap::new();
258    for table in tables {
259        for col in &table.columns {
260            table_lookup.insert(
261                (&table.schema_name, &table.name, &col.name),
262                col.is_nullable,
263            );
264        }
265    }
266
267    // Build view column source lookup: (view_schema, view_name, column_name) -> Vec<is_nullable>
268    let mut view_lookup: HashMap<(&str, &str, &str), Vec<bool>> = HashMap::new();
269    for src in sources {
270        if let Some(&is_nullable) = table_lookup.get(&(
271            src.table_schema.as_str(),
272            src.table_name.as_str(),
273            src.column_name.as_str(),
274        )) {
275            view_lookup
276                .entry((&src.view_schema, &src.view_name, &src.column_name))
277                .or_default()
278                .push(is_nullable);
279        }
280    }
281
282    for view in views.iter_mut() {
283        for col in view.columns.iter_mut() {
284            if let Some(nullable_flags) = view_lookup.get(&(
285                view.schema_name.as_str(),
286                view.name.as_str(),
287                col.name.as_str(),
288            )) {
289                // Only mark as non-nullable if ALL sources are NOT nullable
290                if !nullable_flags.is_empty() && nullable_flags.iter().all(|&n| !n) {
291                    col.is_nullable = false;
292                }
293            }
294        }
295    }
296}
297
298fn resolve_view_primary_keys(
299    views: &mut [TableInfo],
300    sources: &[ViewColumnSource],
301    tables: &[TableInfo],
302) {
303    // Build table column lookup: (schema, table, column) -> is_primary_key
304    let mut table_lookup: HashMap<(&str, &str, &str), bool> = HashMap::new();
305    for table in tables {
306        for col in &table.columns {
307            table_lookup.insert(
308                (&table.schema_name, &table.name, &col.name),
309                col.is_primary_key,
310            );
311        }
312    }
313
314    // Build view column source lookup: (view_schema, view_name, column_name) -> Vec<is_pk>
315    let mut view_lookup: HashMap<(&str, &str, &str), Vec<bool>> = HashMap::new();
316    for src in sources {
317        if let Some(&is_pk) = table_lookup.get(&(
318            src.table_schema.as_str(),
319            src.table_name.as_str(),
320            src.column_name.as_str(),
321        )) {
322            view_lookup
323                .entry((&src.view_schema, &src.view_name, &src.column_name))
324                .or_default()
325                .push(is_pk);
326        }
327    }
328
329    for view in views.iter_mut() {
330        for col in view.columns.iter_mut() {
331            if let Some(pk_flags) = view_lookup.get(&(
332                view.schema_name.as_str(),
333                view.name.as_str(),
334                col.name.as_str(),
335            )) {
336                // Only mark as PK if ALL sources are PKs
337                if !pk_flags.is_empty() && pk_flags.iter().all(|&pk| pk) {
338                    col.is_primary_key = true;
339                }
340            }
341        }
342    }
343}
344
345/// Extract inline ENUMs from column types.
346/// MySQL ENUM('a','b','c') in COLUMN_TYPE gets extracted to an EnumInfo
347/// keyed by table_name + column_name.
348fn extract_enums(tables: &[TableInfo]) -> Vec<EnumInfo> {
349    let mut enums = Vec::new();
350
351    for table in tables {
352        for col in &table.columns {
353            if col.udt_name.starts_with("enum(") {
354                let variants = parse_enum_variants(&col.udt_name);
355                if !variants.is_empty() {
356                    let enum_name = format!("{}_{}", table.name, col.name);
357                    enums.push(EnumInfo {
358                        schema_name: table.schema_name.clone(),
359                        name: enum_name,
360                        variants,
361                        default_variant: None,
362                    });
363                }
364            }
365        }
366    }
367
368    enums
369}
370
/// Parses a MySQL column type such as `enum('a','b','c')` into its variants.
///
/// The input must start with the lowercase prefix `enum(` and end with `)`;
/// anything else (including `ENUM(...)`) yields an empty list.
///
/// The value list is scanned quote-aware rather than split on commas:
/// - a comma inside a quoted value belongs to that value
///   (`enum('a,b','c')` gives `["a,b", "c"]`);
/// - a doubled quote inside a quoted value stands for one literal quote
///   (`enum('it''s')` gives `["it's"]`);
/// - whitespace outside quotes is ignored, whitespace inside is preserved.
///
/// Empty values are dropped from the result.
fn parse_enum_variants(column_type: &str) -> Vec<String> {
    let inner = match column_type
        .strip_prefix("enum(")
        .and_then(|rest| rest.strip_suffix(')'))
    {
        Some(inner) => inner,
        None => return Vec::new(),
    };

    let mut variants = Vec::new();
    let mut current = String::new();
    let mut in_quotes = false;
    let mut chars = inner.chars().peekable();

    while let Some(ch) = chars.next() {
        match ch {
            '\'' if in_quotes => {
                if chars.peek() == Some(&'\'') {
                    // `''` inside a quoted value is an escaped single quote.
                    chars.next();
                    current.push('\'');
                } else {
                    in_quotes = false;
                }
            }
            '\'' => in_quotes = true,
            // Only a comma outside quotes separates two values.
            ',' if !in_quotes => {
                if !current.is_empty() {
                    variants.push(std::mem::take(&mut current));
                }
            }
            // Keep everything inside quotes; outside, skip padding whitespace
            // but stay lenient towards unquoted tokens.
            _ if in_quotes || !ch.is_whitespace() => current.push(ch),
            _ => {}
        }
    }

    if !current.is_empty() {
        variants.push(current);
    }
    variants
}
385
#[cfg(test)]
mod tests {
    use super::*;

    // ========== fixtures ==========

    /// Builds a `varchar` column with no UDT schema and no default.
    fn fixture_column(
        schema: &str,
        name: &str,
        udt_name: &str,
        ordinal: i32,
        is_nullable: bool,
        is_primary_key: bool,
    ) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: "varchar".to_string(),
            udt_name: udt_name.to_string(),
            udt_schema: None,
            is_nullable,
            is_primary_key,
            ordinal_position: ordinal,
            schema_name: schema.to_string(),
            column_default: None,
        }
    }

    /// Builds `varchar(255)` columns from `(name, nullable, primary_key)`
    /// specs, numbering them from zero in the given order.
    fn varchar_columns(schema: &str, specs: Vec<(&str, bool, bool)>) -> Vec<ColumnInfo> {
        specs
            .into_iter()
            .enumerate()
            .map(|(idx, (name, nullable, pk))| {
                fixture_column(schema, name, "varchar(255)", idx as i32, nullable, pk)
            })
            .collect()
    }

    /// Table in the `test_db` schema with the given columns.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "test_db".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// Non-nullable, non-PK column in `test_db` with the given column type.
    fn make_col(name: &str, udt_name: &str) -> ColumnInfo {
        fixture_column("test_db", name, udt_name, 0, false, false)
    }

    /// View whose columns all start out nullable and non-PK.
    fn make_view(schema: &str, name: &str, columns: Vec<&str>) -> TableInfo {
        let specs = columns.into_iter().map(|col| (col, true, false)).collect();
        TableInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            columns: varchar_columns(schema, specs),
        }
    }

    /// Table whose columns carry the given nullability and are never PKs.
    fn make_table_with_nullability(
        schema: &str,
        name: &str,
        columns: Vec<(&str, bool)>,
    ) -> TableInfo {
        let specs = columns
            .into_iter()
            .map(|(col, nullable)| (col, nullable, false))
            .collect();
        TableInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            columns: varchar_columns(schema, specs),
        }
    }

    /// Table whose columns are non-nullable and carry the given PK flag.
    fn make_table_with_pk(schema: &str, name: &str, columns: Vec<(&str, bool)>) -> TableInfo {
        let specs = columns
            .into_iter()
            .map(|(col, is_pk)| (col, false, is_pk))
            .collect();
        TableInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            columns: varchar_columns(schema, specs),
        }
    }

    /// Source row linking a view column to a base-table column of the same name.
    fn make_source(
        view_schema: &str,
        view_name: &str,
        table_schema: &str,
        table_name: &str,
        column_name: &str,
    ) -> ViewColumnSource {
        ViewColumnSource {
            view_schema: view_schema.to_string(),
            view_name: view_name.to_string(),
            table_schema: table_schema.to_string(),
            table_name: table_name.to_string(),
            column_name: column_name.to_string(),
        }
    }

    // ========== parse_enum_variants ==========

    #[test]
    fn test_parse_simple() {
        let parsed = parse_enum_variants("enum('a','b','c')");
        assert_eq!(parsed, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_parse_single_variant() {
        let parsed = parse_enum_variants("enum('only')");
        assert_eq!(parsed, vec!["only"]);
    }

    #[test]
    fn test_parse_with_spaces() {
        let parsed = parse_enum_variants("enum( 'a' , 'b' )");
        assert_eq!(parsed, vec!["a", "b"]);
    }

    #[test]
    fn test_parse_empty_parens() {
        assert!(parse_enum_variants("enum()").is_empty());
    }

    #[test]
    fn test_parse_varchar_not_enum() {
        assert!(parse_enum_variants("varchar(255)").is_empty());
    }

    #[test]
    fn test_parse_int_not_enum() {
        assert!(parse_enum_variants("int").is_empty());
    }

    #[test]
    fn test_parse_with_spaces_in_value() {
        let parsed = parse_enum_variants("enum('with space','no')");
        assert_eq!(parsed, vec!["with space", "no"]);
    }

    #[test]
    fn test_parse_empty_variant_filtered() {
        let parsed = parse_enum_variants("enum('a','','c')");
        assert_eq!(parsed, vec!["a", "c"]);
    }

    #[test]
    fn test_parse_uppercase_enum_not_matched() {
        // The prefix check is case-sensitive: "ENUM(" is not "enum(".
        assert!(parse_enum_variants("ENUM('a','b')").is_empty());
    }

    // ========== extract_enums ==========

    #[test]
    fn test_extract_from_enum_column() {
        let status = make_col("status", "enum('active','inactive')");
        let found = extract_enums(&[make_table("users", vec![status])]);
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].variants, vec!["active", "inactive"]);
    }

    #[test]
    fn test_extract_enum_name_format() {
        let status = make_col("status", "enum('a')");
        let found = extract_enums(&[make_table("users", vec![status])]);
        assert_eq!(found[0].name, "users_status");
    }

    #[test]
    fn test_extract_no_enums() {
        let columns = vec![make_col("id", "int"), make_col("name", "varchar(255)")];
        let found = extract_enums(&[make_table("users", columns)]);
        assert!(found.is_empty());
    }

    #[test]
    fn test_extract_two_enum_columns_same_table() {
        let columns = vec![
            make_col("status", "enum('active','inactive')"),
            make_col("role", "enum('admin','user')"),
        ];
        let found = extract_enums(&[make_table("users", columns)]);
        assert_eq!(found.len(), 2);
        assert_eq!(found[0].name, "users_status");
        assert_eq!(found[1].name, "users_role");
    }

    #[test]
    fn test_extract_enums_from_multiple_tables() {
        let users = make_table("users", vec![make_col("status", "enum('a')")]);
        let posts = make_table("posts", vec![make_col("state", "enum('b')")]);
        let found = extract_enums(&[users, posts]);
        assert_eq!(found.len(), 2);
    }

    #[test]
    fn test_extract_non_enum_column_ignored() {
        let columns = vec![make_col("id", "int(11)"), make_col("status", "enum('a')")];
        let found = extract_enums(&[make_table("users", columns)]);
        assert_eq!(found.len(), 1);
    }

    // ========== resolve_view_nullability ==========

    #[test]
    fn test_resolve_not_null_column() {
        let tables = vec![make_table_with_nullability(
            "db",
            "users",
            vec![("id", false), ("name", false)],
        )];
        let sources = vec![
            make_source("db", "my_view", "db", "users", "id"),
            make_source("db", "my_view", "db", "users", "name"),
        ];
        let mut views = vec![make_view("db", "my_view", vec!["id", "name"])];

        resolve_view_nullability(&mut views, &sources, &tables);

        let columns = &views[0].columns;
        assert!(!columns[0].is_nullable);
        assert!(!columns[1].is_nullable);
    }

    #[test]
    fn test_resolve_nullable_source() {
        let tables = vec![make_table_with_nullability(
            "db",
            "users",
            vec![("id", false), ("name", true)],
        )];
        let sources = vec![
            make_source("db", "my_view", "db", "users", "id"),
            make_source("db", "my_view", "db", "users", "name"),
        ];
        let mut views = vec![make_view("db", "my_view", vec!["id", "name"])];

        resolve_view_nullability(&mut views, &sources, &tables);

        let columns = &views[0].columns;
        assert!(!columns[0].is_nullable);
        assert!(columns[1].is_nullable);
    }

    #[test]
    fn test_resolve_no_match_stays_nullable() {
        let tables = vec![make_table_with_nullability(
            "db",
            "users",
            vec![("id", false)],
        )];
        let sources: Vec<ViewColumnSource> = Vec::new();
        let mut views = vec![make_view("db", "my_view", vec!["computed"])];

        resolve_view_nullability(&mut views, &sources, &tables);

        assert!(views[0].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_empty_sources() {
        let tables: Vec<TableInfo> = Vec::new();
        let mut views = vec![make_view("db", "my_view", vec!["id"])];

        resolve_view_nullability(&mut views, &[], &tables);

        assert!(views[0].columns[0].is_nullable);
    }

    // ========== resolve_view_primary_keys ==========

    #[test]
    fn test_resolve_pk_column() {
        let tables = vec![make_table_with_pk(
            "db",
            "users",
            vec![("id", true), ("name", false)],
        )];
        let sources = vec![
            make_source("db", "my_view", "db", "users", "id"),
            make_source("db", "my_view", "db", "users", "name"),
        ];
        let mut views = vec![make_view("db", "my_view", vec!["id", "name"])];

        resolve_view_primary_keys(&mut views, &sources, &tables);

        let columns = &views[0].columns;
        assert!(columns[0].is_primary_key);
        assert!(!columns[1].is_primary_key);
    }

    #[test]
    fn test_resolve_pk_no_sources() {
        let tables = vec![make_table_with_pk("db", "users", vec![("id", true)])];
        let mut views = vec![make_view("db", "my_view", vec!["id"])];

        resolve_view_primary_keys(&mut views, &[], &tables);

        assert!(!views[0].columns[0].is_primary_key);
    }

    #[test]
    fn test_resolve_pk_no_match() {
        let tables = vec![make_table_with_pk("db", "users", vec![("id", true)])];
        let sources: Vec<ViewColumnSource> = Vec::new();
        let mut views = vec![make_view("db", "my_view", vec!["computed"])];

        resolve_view_primary_keys(&mut views, &sources, &tables);

        assert!(!views[0].columns[0].is_primary_key);
    }
}
714}