Skip to main content

sqlx_gen/introspect/
mysql.rs

1use crate::error::Result;
2use sqlx::MySqlPool;
3
4use super::{ColumnInfo, EnumInfo, SchemaInfo, TableInfo};
5
6pub async fn introspect(
7    pool: &MySqlPool,
8    schemas: &[String],
9    include_views: bool,
10) -> Result<SchemaInfo> {
11    let tables = fetch_tables(pool, schemas).await?;
12    let views = if include_views {
13        fetch_views(pool, schemas).await?
14    } else {
15        Vec::new()
16    };
17    let enums = extract_enums(&tables);
18
19    Ok(SchemaInfo {
20        tables,
21        views,
22        enums,
23        composite_types: Vec::new(),
24        domains: Vec::new(),
25    })
26}
27
28async fn fetch_tables(pool: &MySqlPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
29    // MySQL doesn't support binding arrays directly, so we build placeholders
30    let placeholders: Vec<String> = (0..schemas.len()).map(|_| "?".to_string()).collect();
31    let query = format!(
32        r#"
33        SELECT
34            c.TABLE_SCHEMA,
35            c.TABLE_NAME,
36            c.COLUMN_NAME,
37            c.DATA_TYPE,
38            c.COLUMN_TYPE,
39            c.IS_NULLABLE,
40            c.ORDINAL_POSITION,
41            c.COLUMN_KEY
42        FROM information_schema.COLUMNS c
43        JOIN information_schema.TABLES t
44            ON t.TABLE_SCHEMA = c.TABLE_SCHEMA
45            AND t.TABLE_NAME = c.TABLE_NAME
46            AND t.TABLE_TYPE = 'BASE TABLE'
47        WHERE c.TABLE_SCHEMA IN ({})
48        ORDER BY c.TABLE_SCHEMA, c.TABLE_NAME, c.ORDINAL_POSITION
49        "#,
50        placeholders.join(",")
51    );
52
53    let mut q = sqlx::query_as::<_, (String, String, String, String, String, String, u32, String)>(&query);
54    for schema in schemas {
55        q = q.bind(schema);
56    }
57    let rows = q.fetch_all(pool).await?;
58
59    let mut tables: Vec<TableInfo> = Vec::new();
60    let mut current_key: Option<(String, String)> = None;
61
62    for (schema, table, col_name, data_type, column_type, nullable, ordinal, column_key) in rows {
63        let key = (schema.clone(), table.clone());
64        if current_key.as_ref() != Some(&key) {
65            current_key = Some(key);
66            tables.push(TableInfo {
67                schema_name: schema.clone(),
68                name: table.clone(),
69                columns: Vec::new(),
70            });
71        }
72        tables.last_mut().unwrap().columns.push(ColumnInfo {
73            name: col_name,
74            data_type,
75            udt_name: column_type,
76            is_nullable: nullable == "YES",
77            is_primary_key: column_key == "PRI",
78            ordinal_position: ordinal as i32,
79            schema_name: schema,
80        });
81    }
82
83    Ok(tables)
84}
85
86async fn fetch_views(pool: &MySqlPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
87    let placeholders: Vec<String> = (0..schemas.len()).map(|_| "?".to_string()).collect();
88    let query = format!(
89        r#"
90        SELECT
91            c.TABLE_SCHEMA,
92            c.TABLE_NAME,
93            c.COLUMN_NAME,
94            c.DATA_TYPE,
95            c.COLUMN_TYPE,
96            c.IS_NULLABLE,
97            c.ORDINAL_POSITION
98        FROM information_schema.COLUMNS c
99        JOIN information_schema.TABLES t
100            ON t.TABLE_SCHEMA = c.TABLE_SCHEMA
101            AND t.TABLE_NAME = c.TABLE_NAME
102            AND t.TABLE_TYPE = 'VIEW'
103        WHERE c.TABLE_SCHEMA IN ({})
104        ORDER BY c.TABLE_SCHEMA, c.TABLE_NAME, c.ORDINAL_POSITION
105        "#,
106        placeholders.join(",")
107    );
108
109    let mut q = sqlx::query_as::<_, (String, String, String, String, String, String, u32)>(&query);
110    for schema in schemas {
111        q = q.bind(schema);
112    }
113    let rows = q.fetch_all(pool).await?;
114
115    let mut views: Vec<TableInfo> = Vec::new();
116    let mut current_key: Option<(String, String)> = None;
117
118    for (schema, table, col_name, data_type, column_type, nullable, ordinal) in rows {
119        let key = (schema.clone(), table.clone());
120        if current_key.as_ref() != Some(&key) {
121            current_key = Some(key);
122            views.push(TableInfo {
123                schema_name: schema.clone(),
124                name: table.clone(),
125                columns: Vec::new(),
126            });
127        }
128        views.last_mut().unwrap().columns.push(ColumnInfo {
129            name: col_name,
130            data_type,
131            udt_name: column_type,
132            is_nullable: nullable == "YES",
133            is_primary_key: false,
134            ordinal_position: ordinal as i32,
135            schema_name: schema,
136        });
137    }
138
139    Ok(views)
140}
141
142/// Extract inline ENUMs from column types.
143/// MySQL ENUM('a','b','c') in COLUMN_TYPE gets extracted to an EnumInfo
144/// keyed by table_name + column_name.
145fn extract_enums(tables: &[TableInfo]) -> Vec<EnumInfo> {
146    let mut enums = Vec::new();
147
148    for table in tables {
149        for col in &table.columns {
150            if col.udt_name.starts_with("enum(") {
151                let variants = parse_enum_variants(&col.udt_name);
152                if !variants.is_empty() {
153                    let enum_name = format!("{}_{}", table.name, col.name);
154                    enums.push(EnumInfo {
155                        schema_name: table.schema_name.clone(),
156                        name: enum_name,
157                        variants,
158                    });
159                }
160            }
161        }
162    }
163
164    enums
165}
166
/// Parse the variant list out of a MySQL `COLUMN_TYPE` such as `enum('a','b','c')`.
///
/// Returns the unquoted variants in declaration order, e.g.
/// `"enum('a','b','c')"` → `["a", "b", "c"]`. Input that is not of the form
/// `enum(...)` with a lowercase `enum(` prefix yields an empty list.
///
/// Quoting rules:
/// - a comma inside quotes is part of the value (`enum('a,b','c')` → `["a,b", "c"]`);
/// - a doubled quote inside a value is an escaped quote (`enum('it''s')` → `["it's"]`);
/// - whitespace outside quotes is ignored; whitespace inside quotes is kept.
///
/// Empty variants are dropped.
// NOTE(review): backslash sequences inside values are kept verbatim; confirm
// how the target MySQL version renders backslashes in COLUMN_TYPE.
fn parse_enum_variants(column_type: &str) -> Vec<String> {
    let inner = match column_type
        .strip_prefix("enum(")
        .and_then(|s| s.strip_suffix(')'))
    {
        Some(inner) => inner,
        None => return Vec::new(),
    };

    let mut variants = Vec::new();
    let mut current = String::new();
    let mut in_quotes = false;
    let mut chars = inner.chars().peekable();

    while let Some(ch) = chars.next() {
        if in_quotes {
            if ch != '\'' {
                current.push(ch);
            } else if chars.peek() == Some(&'\'') {
                // `''` inside a quoted value is an escaped single quote.
                chars.next();
                current.push('\'');
            } else {
                in_quotes = false;
            }
        } else {
            match ch {
                '\'' => in_quotes = true,
                ',' => {
                    if !current.is_empty() {
                        variants.push(std::mem::take(&mut current));
                    }
                }
                c if c.is_whitespace() => {}
                // Unquoted text is kept (minus whitespace) so malformed input
                // degrades gracefully instead of being silently discarded.
                c => current.push(c),
            }
        }
    }
    if !current.is_empty() {
        variants.push(current);
    }

    variants
}

#[cfg(test)]
mod quoting_tests {
    use super::*;

    #[test]
    fn comma_inside_quotes_stays_in_value() {
        assert_eq!(parse_enum_variants("enum('a,b','c')"), vec!["a,b", "c"]);
    }

    #[test]
    fn doubled_quote_is_unescaped() {
        assert_eq!(parse_enum_variants("enum('it''s','x')"), vec!["it's", "x"]);
    }
}
181
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a table in the fixed `test_db` schema.
    fn table_of(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: String::from("test_db"),
            name: String::from(name),
            columns,
        }
    }

    /// Build a non-nullable, non-key `varchar` column in `test_db` whose
    /// `udt_name` is the given column type string.
    fn varchar_col(name: &str, udt_name: &str) -> ColumnInfo {
        ColumnInfo {
            name: String::from(name),
            data_type: String::from("varchar"),
            udt_name: String::from(udt_name),
            is_nullable: false,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: String::from("test_db"),
        }
    }

    // ---------- parse_enum_variants ----------

    #[test]
    fn test_parse_simple() {
        let parsed = parse_enum_variants("enum('a','b','c')");
        assert_eq!(parsed, ["a", "b", "c"]);
    }

    #[test]
    fn test_parse_single_variant() {
        let parsed = parse_enum_variants("enum('only')");
        assert_eq!(parsed, ["only"]);
    }

    #[test]
    fn test_parse_with_spaces() {
        let parsed = parse_enum_variants("enum( 'a' , 'b' )");
        assert_eq!(parsed, ["a", "b"]);
    }

    #[test]
    fn test_parse_empty_parens() {
        assert_eq!(parse_enum_variants("enum()"), Vec::<String>::new());
    }

    #[test]
    fn test_parse_varchar_not_enum() {
        assert_eq!(parse_enum_variants("varchar(255)"), Vec::<String>::new());
    }

    #[test]
    fn test_parse_int_not_enum() {
        assert_eq!(parse_enum_variants("int"), Vec::<String>::new());
    }

    #[test]
    fn test_parse_with_spaces_in_value() {
        let parsed = parse_enum_variants("enum('with space','no')");
        assert_eq!(parsed, ["with space", "no"]);
    }

    #[test]
    fn test_parse_empty_variant_filtered() {
        let parsed = parse_enum_variants("enum('a','','c')");
        assert_eq!(parsed, ["a", "c"]);
    }

    #[test]
    fn test_parse_uppercase_enum_not_matched() {
        // Prefix matching is case-sensitive: `ENUM(` is not recognised.
        assert_eq!(parse_enum_variants("ENUM('a','b')"), Vec::<String>::new());
    }

    // ---------- extract_enums ----------

    #[test]
    fn test_extract_from_enum_column() {
        let tables = [table_of(
            "users",
            vec![varchar_col("status", "enum('active','inactive')")],
        )];
        let found = extract_enums(&tables);
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].variants, ["active", "inactive"]);
    }

    #[test]
    fn test_extract_enum_name_format() {
        let tables = [table_of("users", vec![varchar_col("status", "enum('a')")])];
        let found = extract_enums(&tables);
        assert_eq!(found[0].name, "users_status");
    }

    #[test]
    fn test_extract_no_enums() {
        let tables = [table_of(
            "users",
            vec![varchar_col("id", "int"), varchar_col("name", "varchar(255)")],
        )];
        assert!(extract_enums(&tables).is_empty());
    }

    #[test]
    fn test_extract_two_enum_columns_same_table() {
        let tables = [table_of(
            "users",
            vec![
                varchar_col("status", "enum('active','inactive')"),
                varchar_col("role", "enum('admin','user')"),
            ],
        )];
        let found = extract_enums(&tables);
        let names: Vec<&str> = found.iter().map(|e| e.name.as_str()).collect();
        assert_eq!(names, ["users_status", "users_role"]);
    }

    #[test]
    fn test_extract_enums_from_multiple_tables() {
        let tables = [
            table_of("users", vec![varchar_col("status", "enum('a')")]),
            table_of("posts", vec![varchar_col("state", "enum('b')")]),
        ];
        assert_eq!(extract_enums(&tables).len(), 2);
    }

    #[test]
    fn test_extract_non_enum_column_ignored() {
        let tables = [table_of(
            "users",
            vec![
                varchar_col("id", "int(11)"),
                varchar_col("status", "enum('a')"),
            ],
        )];
        assert_eq!(extract_enums(&tables).len(), 1);
    }
}