1use std::collections::HashMap;
2
3use crate::error::Result;
4use sqlx::MySqlPool;
5
6use super::{ColumnInfo, EnumInfo, SchemaInfo, TableInfo};
7
8pub async fn introspect(
9 pool: &MySqlPool,
10 schemas: &[String],
11 include_views: bool,
12) -> Result<SchemaInfo> {
13 let tables = fetch_tables(pool, schemas).await?;
14 let mut views = if include_views {
15 fetch_views(pool, schemas).await?
16 } else {
17 Vec::new()
18 };
19
20 if !views.is_empty() {
21 let sources = fetch_view_column_sources(pool, schemas).await?;
22 resolve_view_nullability(&mut views, &sources, &tables);
23 resolve_view_primary_keys(&mut views, &sources, &tables);
24 }
25
26 let enums = extract_enums(&tables);
27
28 Ok(SchemaInfo {
29 tables,
30 views,
31 enums,
32 composite_types: Vec::new(),
33 domains: Vec::new(),
34 })
35}
36
37async fn fetch_tables(pool: &MySqlPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
38 let placeholders: Vec<String> = (0..schemas.len()).map(|_| "?".to_string()).collect();
40 let query = format!(
41 r#"
42 SELECT
43 c.TABLE_SCHEMA,
44 c.TABLE_NAME,
45 c.COLUMN_NAME,
46 c.DATA_TYPE,
47 c.COLUMN_TYPE,
48 c.IS_NULLABLE,
49 c.ORDINAL_POSITION,
50 c.COLUMN_KEY
51 FROM information_schema.COLUMNS c
52 JOIN information_schema.TABLES t
53 ON t.TABLE_SCHEMA = c.TABLE_SCHEMA
54 AND t.TABLE_NAME = c.TABLE_NAME
55 AND t.TABLE_TYPE = 'BASE TABLE'
56 WHERE c.TABLE_SCHEMA IN ({})
57 ORDER BY c.TABLE_SCHEMA, c.TABLE_NAME, c.ORDINAL_POSITION
58 "#,
59 placeholders.join(",")
60 );
61
62 let mut q = sqlx::query_as::<_, (String, String, String, String, String, String, u32, String)>(&query);
63 for schema in schemas {
64 q = q.bind(schema);
65 }
66 let rows = q.fetch_all(pool).await?;
67
68 let mut tables: Vec<TableInfo> = Vec::new();
69 let mut current_key: Option<(String, String)> = None;
70
71 for (schema, table, col_name, data_type, column_type, nullable, ordinal, column_key) in rows {
72 let key = (schema.clone(), table.clone());
73 if current_key.as_ref() != Some(&key) {
74 current_key = Some(key);
75 tables.push(TableInfo {
76 schema_name: schema.clone(),
77 name: table.clone(),
78 columns: Vec::new(),
79 });
80 }
81 tables.last_mut().unwrap().columns.push(ColumnInfo {
82 name: col_name,
83 data_type,
84 udt_name: column_type,
85 is_nullable: nullable == "YES",
86 is_primary_key: column_key == "PRI",
87 ordinal_position: ordinal as i32,
88 schema_name: schema,
89 column_default: None,
90 });
91 }
92
93 Ok(tables)
94}
95
96async fn fetch_views(pool: &MySqlPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
97 let placeholders: Vec<String> = (0..schemas.len()).map(|_| "?".to_string()).collect();
98 let query = format!(
99 r#"
100 SELECT
101 c.TABLE_SCHEMA,
102 c.TABLE_NAME,
103 c.COLUMN_NAME,
104 c.DATA_TYPE,
105 c.COLUMN_TYPE,
106 c.IS_NULLABLE,
107 c.ORDINAL_POSITION
108 FROM information_schema.COLUMNS c
109 JOIN information_schema.TABLES t
110 ON t.TABLE_SCHEMA = c.TABLE_SCHEMA
111 AND t.TABLE_NAME = c.TABLE_NAME
112 AND t.TABLE_TYPE = 'VIEW'
113 WHERE c.TABLE_SCHEMA IN ({})
114 ORDER BY c.TABLE_SCHEMA, c.TABLE_NAME, c.ORDINAL_POSITION
115 "#,
116 placeholders.join(",")
117 );
118
119 let mut q = sqlx::query_as::<_, (String, String, String, String, String, String, u32)>(&query);
120 for schema in schemas {
121 q = q.bind(schema);
122 }
123 let rows = q.fetch_all(pool).await?;
124
125 let mut views: Vec<TableInfo> = Vec::new();
126 let mut current_key: Option<(String, String)> = None;
127
128 for (schema, table, col_name, data_type, column_type, nullable, ordinal) in rows {
129 let key = (schema.clone(), table.clone());
130 if current_key.as_ref() != Some(&key) {
131 current_key = Some(key);
132 views.push(TableInfo {
133 schema_name: schema.clone(),
134 name: table.clone(),
135 columns: Vec::new(),
136 });
137 }
138 views.last_mut().unwrap().columns.push(ColumnInfo {
139 name: col_name,
140 data_type,
141 udt_name: column_type,
142 is_nullable: nullable == "YES",
143 is_primary_key: false,
144 ordinal_position: ordinal as i32,
145 schema_name: schema,
146 column_default: None,
147 });
148 }
149
150 Ok(views)
151}
152
/// One row of `INFORMATION_SCHEMA.VIEW_COLUMN_USAGE`: a view together with a base-table
/// column that its definition references.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ViewColumnSource {
    /// Schema containing the view.
    view_schema: String,
    /// Name of the view.
    view_name: String,
    /// Schema of the referenced table.
    table_schema: String,
    /// Name of the referenced table.
    table_name: String,
    /// Name of the referenced column in that table.
    column_name: String,
}
160
161async fn fetch_view_column_sources(
162 pool: &MySqlPool,
163 schemas: &[String],
164) -> Result<Vec<ViewColumnSource>> {
165 let placeholders: Vec<String> = (0..schemas.len()).map(|_| "?".to_string()).collect();
166 let query = format!(
167 r#"
168 SELECT
169 vcu.VIEW_SCHEMA,
170 vcu.VIEW_NAME,
171 vcu.TABLE_SCHEMA,
172 vcu.TABLE_NAME,
173 vcu.COLUMN_NAME
174 FROM INFORMATION_SCHEMA.VIEW_COLUMN_USAGE vcu
175 WHERE vcu.VIEW_SCHEMA IN ({})
176 "#,
177 placeholders.join(",")
178 );
179
180 let mut q = sqlx::query_as::<_, (String, String, String, String, String)>(&query);
181 for schema in schemas {
182 q = q.bind(schema);
183 }
184
185 match q.fetch_all(pool).await {
186 Ok(rows) => Ok(rows
187 .into_iter()
188 .map(
189 |(view_schema, view_name, table_schema, table_name, column_name)| {
190 ViewColumnSource {
191 view_schema,
192 view_name,
193 table_schema,
194 table_name,
195 column_name,
196 }
197 },
198 )
199 .collect()),
200 Err(_) => {
201 Ok(Vec::new())
203 }
204 }
205}
206
207fn resolve_view_nullability(
208 views: &mut [TableInfo],
209 sources: &[ViewColumnSource],
210 tables: &[TableInfo],
211) {
212 let mut table_lookup: HashMap<(&str, &str, &str), bool> = HashMap::new();
214 for table in tables {
215 for col in &table.columns {
216 table_lookup.insert(
217 (&table.schema_name, &table.name, &col.name),
218 col.is_nullable,
219 );
220 }
221 }
222
223 let mut view_lookup: HashMap<(&str, &str, &str), Vec<bool>> = HashMap::new();
225 for src in sources {
226 if let Some(&is_nullable) =
227 table_lookup.get(&(src.table_schema.as_str(), src.table_name.as_str(), src.column_name.as_str()))
228 {
229 view_lookup
230 .entry((&src.view_schema, &src.view_name, &src.column_name))
231 .or_default()
232 .push(is_nullable);
233 }
234 }
235
236 for view in views.iter_mut() {
237 for col in view.columns.iter_mut() {
238 if let Some(nullable_flags) = view_lookup.get(&(
239 view.schema_name.as_str(),
240 view.name.as_str(),
241 col.name.as_str(),
242 )) {
243 if !nullable_flags.is_empty() && nullable_flags.iter().all(|&n| !n) {
245 col.is_nullable = false;
246 }
247 }
248 }
249 }
250}
251
252fn resolve_view_primary_keys(
253 views: &mut [TableInfo],
254 sources: &[ViewColumnSource],
255 tables: &[TableInfo],
256) {
257 let mut table_lookup: HashMap<(&str, &str, &str), bool> = HashMap::new();
259 for table in tables {
260 for col in &table.columns {
261 table_lookup.insert(
262 (&table.schema_name, &table.name, &col.name),
263 col.is_primary_key,
264 );
265 }
266 }
267
268 let mut view_lookup: HashMap<(&str, &str, &str), Vec<bool>> = HashMap::new();
270 for src in sources {
271 if let Some(&is_pk) =
272 table_lookup.get(&(src.table_schema.as_str(), src.table_name.as_str(), src.column_name.as_str()))
273 {
274 view_lookup
275 .entry((&src.view_schema, &src.view_name, &src.column_name))
276 .or_default()
277 .push(is_pk);
278 }
279 }
280
281 for view in views.iter_mut() {
282 for col in view.columns.iter_mut() {
283 if let Some(pk_flags) = view_lookup.get(&(
284 view.schema_name.as_str(),
285 view.name.as_str(),
286 col.name.as_str(),
287 )) {
288 if !pk_flags.is_empty() && pk_flags.iter().all(|&pk| pk) {
290 col.is_primary_key = true;
291 }
292 }
293 }
294 }
295}
296
297fn extract_enums(tables: &[TableInfo]) -> Vec<EnumInfo> {
301 let mut enums = Vec::new();
302
303 for table in tables {
304 for col in &table.columns {
305 if col.udt_name.starts_with("enum(") {
306 let variants = parse_enum_variants(&col.udt_name);
307 if !variants.is_empty() {
308 let enum_name = format!("{}_{}", table.name, col.name);
309 enums.push(EnumInfo {
310 schema_name: table.schema_name.clone(),
311 name: enum_name,
312 variants,
313 default_variant: None,
314 });
315 }
316 }
317 }
318 }
319
320 enums
321}
322
/// Parses the variant list out of a MySQL `COLUMN_TYPE` such as `enum('a','b')`.
///
/// MySQL renders each variant as a single-quoted literal, doubling embedded quotes
/// (`''`) and backslash-escaping `\`, NUL, newline and carriage return
/// (NOTE(review): escape set based on MySQL's identifier/literal quoting — verify on the
/// target server). Those escapes are decoded, so a variant containing a comma or a quote
/// survives intact instead of being split or truncated.
///
/// Whitespace around items is tolerated, unquoted items are accepted verbatim (trimmed),
/// and empty variants are skipped. Input that does not start with the lowercase prefix
/// `enum(` and end with `)` yields an empty list.
fn parse_enum_variants(column_type: &str) -> Vec<String> {
    let inner = match column_type
        .strip_prefix("enum(")
        .and_then(|rest| rest.strip_suffix(')'))
    {
        Some(inner) => inner,
        None => return Vec::new(),
    };

    let mut variants = Vec::new();
    let mut chars = inner.chars().peekable();
    loop {
        while chars.next_if(|ch| ch.is_whitespace()).is_some() {}
        if chars.peek().is_none() {
            break;
        }

        let mut value = String::new();
        if chars.next_if_eq(&'\'').is_some() {
            // Quoted literal: read until the closing quote, decoding escapes.
            while let Some(ch) = chars.next() {
                match ch {
                    '\'' => {
                        if chars.next_if_eq(&'\'').is_some() {
                            value.push('\'');
                        } else {
                            break;
                        }
                    }
                    '\\' => match chars.next() {
                        Some('0') => value.push('\0'),
                        Some('n') => value.push('\n'),
                        Some('r') => value.push('\r'),
                        Some(escaped) => value.push(escaped),
                        None => break,
                    },
                    other => value.push(other),
                }
            }
        } else {
            // Unquoted item (not produced by MySQL, kept for leniency): up to the comma.
            while let Some(ch) = chars.next_if(|&ch| ch != ',') {
                value.push(ch);
            }
            value.truncate(value.trim_end().len());
        }

        if !value.is_empty() {
            variants.push(value);
        }

        // Skip anything up to the separator, then the separator itself.
        while chars.next_if(|&ch| ch != ',').is_some() {}
        if chars.next().is_none() {
            break;
        }
    }

    variants
}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture column: a `varchar` whose remaining attributes come from the caller.
    fn column(
        schema: &str,
        name: &str,
        udt_name: &str,
        is_nullable: bool,
        is_primary_key: bool,
        ordinal_position: i32,
    ) -> ColumnInfo {
        ColumnInfo {
            name: name.to_owned(),
            data_type: String::from("varchar"),
            udt_name: udt_name.to_owned(),
            is_nullable,
            is_primary_key,
            ordinal_position,
            schema_name: schema.to_owned(),
            column_default: None,
        }
    }

    /// Wraps `columns` in a table/view named `name` inside `schema`.
    fn relation(schema: &str, name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: schema.to_owned(),
            name: name.to_owned(),
            columns,
        }
    }

    /// A table in the fixed `test_db` schema.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        relation("test_db", name, columns)
    }

    /// A NOT NULL, non-key `test_db` column with the given `COLUMN_TYPE`.
    fn make_col(name: &str, udt_name: &str) -> ColumnInfo {
        column("test_db", name, udt_name, false, false, 0)
    }

    /// A view whose columns all start out nullable and non-key.
    fn make_view(schema: &str, name: &str, columns: Vec<&str>) -> TableInfo {
        let cols = columns
            .into_iter()
            .enumerate()
            .map(|(pos, col)| column(schema, col, "varchar(255)", true, false, pos as i32))
            .collect();
        relation(schema, name, cols)
    }

    /// A non-key table whose per-column nullability is given explicitly.
    fn make_table_with_nullability(
        schema: &str,
        name: &str,
        columns: Vec<(&str, bool)>,
    ) -> TableInfo {
        let cols = columns
            .into_iter()
            .enumerate()
            .map(|(pos, (col, nullable))| {
                column(schema, col, "varchar(255)", nullable, false, pos as i32)
            })
            .collect();
        relation(schema, name, cols)
    }

    /// A NOT NULL table whose per-column primary-key membership is given explicitly.
    fn make_table_with_pk(schema: &str, name: &str, columns: Vec<(&str, bool)>) -> TableInfo {
        let cols = columns
            .into_iter()
            .enumerate()
            .map(|(pos, (col, is_pk))| column(schema, col, "varchar(255)", false, is_pk, pos as i32))
            .collect();
        relation(schema, name, cols)
    }

    /// One `VIEW_COLUMN_USAGE` row.
    fn make_source(
        view_schema: &str,
        view_name: &str,
        table_schema: &str,
        table_name: &str,
        column_name: &str,
    ) -> ViewColumnSource {
        ViewColumnSource {
            view_schema: view_schema.to_owned(),
            view_name: view_name.to_owned(),
            table_schema: table_schema.to_owned(),
            table_name: table_name.to_owned(),
            column_name: column_name.to_owned(),
        }
    }

    /// Collects one boolean attribute from every column of `rel`, in column order.
    fn flags(rel: &TableInfo, pick: fn(&ColumnInfo) -> bool) -> Vec<bool> {
        rel.columns.iter().map(pick).collect()
    }

    // ---- parse_enum_variants -------------------------------------------------------

    #[test]
    fn test_parse_simple() {
        assert_eq!(parse_enum_variants("enum('a','b','c')"), ["a", "b", "c"]);
    }

    #[test]
    fn test_parse_single_variant() {
        assert_eq!(parse_enum_variants("enum('only')"), ["only"]);
    }

    #[test]
    fn test_parse_with_spaces() {
        assert_eq!(parse_enum_variants("enum( 'a' , 'b' )"), ["a", "b"]);
    }

    #[test]
    fn test_parse_empty_parens() {
        assert_eq!(parse_enum_variants("enum()"), Vec::<String>::new());
    }

    #[test]
    fn test_parse_varchar_not_enum() {
        assert_eq!(parse_enum_variants("varchar(255)"), Vec::<String>::new());
    }

    #[test]
    fn test_parse_int_not_enum() {
        assert_eq!(parse_enum_variants("int"), Vec::<String>::new());
    }

    #[test]
    fn test_parse_with_spaces_in_value() {
        assert_eq!(
            parse_enum_variants("enum('with space','no')"),
            ["with space", "no"]
        );
    }

    #[test]
    fn test_parse_empty_variant_filtered() {
        assert_eq!(parse_enum_variants("enum('a','','c')"), ["a", "c"]);
    }

    #[test]
    fn test_parse_uppercase_enum_not_matched() {
        // Only the lowercase `enum(` prefix that MySQL emits is recognised.
        assert_eq!(parse_enum_variants("ENUM('a','b')"), Vec::<String>::new());
    }

    // ---- extract_enums -------------------------------------------------------------

    #[test]
    fn test_extract_from_enum_column() {
        let enums = extract_enums(&[make_table(
            "users",
            vec![make_col("status", "enum('active','inactive')")],
        )]);
        assert_eq!(enums.len(), 1);
        assert_eq!(enums[0].variants, ["active", "inactive"]);
    }

    #[test]
    fn test_extract_enum_name_format() {
        let enums = extract_enums(&[make_table("users", vec![make_col("status", "enum('a')")])]);
        assert_eq!(enums[0].name, "users_status");
    }

    #[test]
    fn test_extract_no_enums() {
        let enums = extract_enums(&[make_table(
            "users",
            vec![make_col("id", "int"), make_col("name", "varchar(255)")],
        )]);
        assert!(enums.is_empty());
    }

    #[test]
    fn test_extract_two_enum_columns_same_table() {
        let enums = extract_enums(&[make_table(
            "users",
            vec![
                make_col("status", "enum('active','inactive')"),
                make_col("role", "enum('admin','user')"),
            ],
        )]);
        let names: Vec<&str> = enums.iter().map(|e| e.name.as_str()).collect();
        assert_eq!(names, ["users_status", "users_role"]);
    }

    #[test]
    fn test_extract_enums_from_multiple_tables() {
        let tables = [
            make_table("users", vec![make_col("status", "enum('a')")]),
            make_table("posts", vec![make_col("state", "enum('b')")]),
        ];
        assert_eq!(extract_enums(&tables).len(), 2);
    }

    #[test]
    fn test_extract_non_enum_column_ignored() {
        let tables = [make_table(
            "users",
            vec![make_col("id", "int(11)"), make_col("status", "enum('a')")],
        )];
        assert_eq!(extract_enums(&tables).len(), 1);
    }

    // ---- resolve_view_nullability --------------------------------------------------

    #[test]
    fn test_resolve_not_null_column() {
        let tables = [make_table_with_nullability(
            "db",
            "users",
            vec![("id", false), ("name", false)],
        )];
        let mut views = [make_view("db", "my_view", vec!["id", "name"])];
        let sources = [
            make_source("db", "my_view", "db", "users", "id"),
            make_source("db", "my_view", "db", "users", "name"),
        ];
        resolve_view_nullability(&mut views, &sources, &tables);
        assert_eq!(flags(&views[0], |c| c.is_nullable), [false, false]);
    }

    #[test]
    fn test_resolve_nullable_source() {
        let tables = [make_table_with_nullability(
            "db",
            "users",
            vec![("id", false), ("name", true)],
        )];
        let mut views = [make_view("db", "my_view", vec!["id", "name"])];
        let sources = [
            make_source("db", "my_view", "db", "users", "id"),
            make_source("db", "my_view", "db", "users", "name"),
        ];
        resolve_view_nullability(&mut views, &sources, &tables);
        assert_eq!(flags(&views[0], |c| c.is_nullable), [false, true]);
    }

    #[test]
    fn test_resolve_no_match_stays_nullable() {
        let tables = [make_table_with_nullability("db", "users", vec![("id", false)])];
        let mut views = [make_view("db", "my_view", vec!["computed"])];
        resolve_view_nullability(&mut views, &[], &tables);
        assert_eq!(flags(&views[0], |c| c.is_nullable), [true]);
    }

    #[test]
    fn test_resolve_empty_sources() {
        let mut views = [make_view("db", "my_view", vec!["id"])];
        resolve_view_nullability(&mut views, &[], &[]);
        assert_eq!(flags(&views[0], |c| c.is_nullable), [true]);
    }

    // ---- resolve_view_primary_keys -------------------------------------------------

    #[test]
    fn test_resolve_pk_column() {
        let tables = [make_table_with_pk("db", "users", vec![("id", true), ("name", false)])];
        let mut views = [make_view("db", "my_view", vec!["id", "name"])];
        let sources = [
            make_source("db", "my_view", "db", "users", "id"),
            make_source("db", "my_view", "db", "users", "name"),
        ];
        resolve_view_primary_keys(&mut views, &sources, &tables);
        assert_eq!(flags(&views[0], |c| c.is_primary_key), [true, false]);
    }

    #[test]
    fn test_resolve_pk_no_sources() {
        let tables = [make_table_with_pk("db", "users", vec![("id", true)])];
        let mut views = [make_view("db", "my_view", vec!["id"])];
        resolve_view_primary_keys(&mut views, &[], &tables);
        assert_eq!(flags(&views[0], |c| c.is_primary_key), [false]);
    }

    #[test]
    fn test_resolve_pk_no_match() {
        let tables = [make_table_with_pk("db", "users", vec![("id", true)])];
        let mut views = [make_view("db", "my_view", vec!["computed"])];
        resolve_view_primary_keys(&mut views, &[], &tables);
        assert_eq!(flags(&views[0], |c| c.is_primary_key), [false]);
    }
}