1use std::collections::HashMap;
2
3use crate::error::Result;
4use sqlx::MySqlPool;
5
6use super::{ColumnInfo, EnumInfo, SchemaInfo, TableInfo};
7
8pub async fn introspect(
9 pool: &MySqlPool,
10 schemas: &[String],
11 include_views: bool,
12) -> Result<SchemaInfo> {
13 let tables = fetch_tables(pool, schemas).await?;
14 let mut views = if include_views {
15 fetch_views(pool, schemas).await?
16 } else {
17 Vec::new()
18 };
19
20 if !views.is_empty() {
21 let sources = fetch_view_column_sources(pool, schemas).await?;
22 resolve_view_nullability(&mut views, &sources, &tables);
23 resolve_view_primary_keys(&mut views, &sources, &tables);
24 }
25
26 let enums = extract_enums(&tables);
27
28 Ok(SchemaInfo {
29 tables,
30 views,
31 enums,
32 composite_types: Vec::new(),
33 domains: Vec::new(),
34 })
35}
36
37async fn fetch_tables(pool: &MySqlPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
38 let placeholders: Vec<String> = (0..schemas.len()).map(|_| "?".to_string()).collect();
40 let query = format!(
41 r#"
42 SELECT
43 c.TABLE_SCHEMA,
44 c.TABLE_NAME,
45 c.COLUMN_NAME,
46 c.DATA_TYPE,
47 c.COLUMN_TYPE,
48 c.IS_NULLABLE,
49 c.ORDINAL_POSITION,
50 c.COLUMN_KEY
51 FROM information_schema.COLUMNS c
52 JOIN information_schema.TABLES t
53 ON t.TABLE_SCHEMA = c.TABLE_SCHEMA
54 AND t.TABLE_NAME = c.TABLE_NAME
55 AND t.TABLE_TYPE = 'BASE TABLE'
56 WHERE c.TABLE_SCHEMA IN ({})
57 ORDER BY c.TABLE_SCHEMA, c.TABLE_NAME, c.ORDINAL_POSITION
58 "#,
59 placeholders.join(",")
60 );
61
62 let mut q = sqlx::query_as::<_, (Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>, Vec<u8>, u32, Vec<u8>)>(&query);
63 for schema in schemas {
64 q = q.bind(schema);
65 }
66 let rows = q.fetch_all(pool).await?;
67
68 let mut tables: Vec<TableInfo> = Vec::new();
69 let mut current_key: Option<(String, String)> = None;
70
71 for (schema, table, col_name, data_type, column_type, nullable, ordinal, column_key) in rows {
72 let schema = String::from_utf8(schema).expect("Could not convert schema name from UTF8 bytes");
73 let table = String::from_utf8(table).expect("Could not convert schema name from UTF8 bytes");
74 let col_name = String::from_utf8(col_name).expect("Could not convert col_name name from UTF8 bytes");
75 let data_type = String::from_utf8(data_type).expect("Could not convert data_type name from UTF8 bytes");
76 let column_type = String::from_utf8(column_type).expect("Could not convert column_type name from UTF8 bytes");
77 let nullable = String::from_utf8(nullable).expect("Could not convert nullable name from UTF8 bytes");
78 let column_key = String::from_utf8(column_key).expect("Could not convert column_key name from UTF8 bytes");
79
80 let key = (schema.clone(), table.clone());
81 if current_key.as_ref() != Some(&key) {
82 current_key = Some(key);
83 tables.push(TableInfo {
84 schema_name: schema.clone(),
85 name: table.clone(),
86 columns: Vec::new(),
87 });
88 }
89 tables.last_mut().unwrap().columns.push(ColumnInfo {
90 name: col_name,
91 data_type,
92 udt_name: column_type,
93 is_nullable: nullable == "YES",
94 is_primary_key: column_key == "PRI",
95 ordinal_position: ordinal as i32,
96 schema_name: schema,
97 column_default: None,
98 });
99 }
100
101 Ok(tables)
102}
103
104async fn fetch_views(pool: &MySqlPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
105 let placeholders: Vec<String> = (0..schemas.len()).map(|_| "?".to_string()).collect();
106 let query = format!(
107 r#"
108 SELECT
109 c.TABLE_SCHEMA,
110 c.TABLE_NAME,
111 c.COLUMN_NAME,
112 c.DATA_TYPE,
113 c.COLUMN_TYPE,
114 c.IS_NULLABLE,
115 c.ORDINAL_POSITION
116 FROM information_schema.COLUMNS c
117 JOIN information_schema.TABLES t
118 ON t.TABLE_SCHEMA = c.TABLE_SCHEMA
119 AND t.TABLE_NAME = c.TABLE_NAME
120 AND t.TABLE_TYPE = 'VIEW'
121 WHERE c.TABLE_SCHEMA IN ({})
122 ORDER BY c.TABLE_SCHEMA, c.TABLE_NAME, c.ORDINAL_POSITION
123 "#,
124 placeholders.join(",")
125 );
126
127 let mut q = sqlx::query_as::<_, (String, String, String, String, String, String, u32)>(&query);
128 for schema in schemas {
129 q = q.bind(schema);
130 }
131 let rows = q.fetch_all(pool).await?;
132
133 let mut views: Vec<TableInfo> = Vec::new();
134 let mut current_key: Option<(String, String)> = None;
135
136 for (schema, table, col_name, data_type, column_type, nullable, ordinal) in rows {
137 let key = (schema.clone(), table.clone());
138 if current_key.as_ref() != Some(&key) {
139 current_key = Some(key);
140 views.push(TableInfo {
141 schema_name: schema.clone(),
142 name: table.clone(),
143 columns: Vec::new(),
144 });
145 }
146 views.last_mut().unwrap().columns.push(ColumnInfo {
147 name: col_name,
148 data_type,
149 udt_name: column_type,
150 is_nullable: nullable == "YES",
151 is_primary_key: false,
152 ordinal_position: ordinal as i32,
153 schema_name: schema,
154 column_default: None,
155 });
156 }
157
158 Ok(views)
159}
160
/// One `information_schema.VIEW_COLUMN_USAGE` row: a table column that a view
/// references.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ViewColumnSource {
    /// Schema containing the view.
    view_schema: String,
    /// Name of the view.
    view_name: String,
    /// Schema of the referenced table.
    table_schema: String,
    /// Name of the referenced table.
    table_name: String,
    /// Name of the referenced column in that table.
    column_name: String,
}
168
169async fn fetch_view_column_sources(
170 pool: &MySqlPool,
171 schemas: &[String],
172) -> Result<Vec<ViewColumnSource>> {
173 let placeholders: Vec<String> = (0..schemas.len()).map(|_| "?".to_string()).collect();
174 let query = format!(
175 r#"
176 SELECT
177 vcu.VIEW_SCHEMA,
178 vcu.VIEW_NAME,
179 vcu.TABLE_SCHEMA,
180 vcu.TABLE_NAME,
181 vcu.COLUMN_NAME
182 FROM INFORMATION_SCHEMA.VIEW_COLUMN_USAGE vcu
183 WHERE vcu.VIEW_SCHEMA IN ({})
184 "#,
185 placeholders.join(",")
186 );
187
188 let mut q = sqlx::query_as::<_, (String, String, String, String, String)>(&query);
189 for schema in schemas {
190 q = q.bind(schema);
191 }
192
193 match q.fetch_all(pool).await {
194 Ok(rows) => Ok(rows
195 .into_iter()
196 .map(
197 |(view_schema, view_name, table_schema, table_name, column_name)| {
198 ViewColumnSource {
199 view_schema,
200 view_name,
201 table_schema,
202 table_name,
203 column_name,
204 }
205 },
206 )
207 .collect()),
208 Err(_) => {
209 Ok(Vec::new())
211 }
212 }
213}
214
215fn resolve_view_nullability(
216 views: &mut [TableInfo],
217 sources: &[ViewColumnSource],
218 tables: &[TableInfo],
219) {
220 let mut table_lookup: HashMap<(&str, &str, &str), bool> = HashMap::new();
222 for table in tables {
223 for col in &table.columns {
224 table_lookup.insert(
225 (&table.schema_name, &table.name, &col.name),
226 col.is_nullable,
227 );
228 }
229 }
230
231 let mut view_lookup: HashMap<(&str, &str, &str), Vec<bool>> = HashMap::new();
233 for src in sources {
234 if let Some(&is_nullable) =
235 table_lookup.get(&(src.table_schema.as_str(), src.table_name.as_str(), src.column_name.as_str()))
236 {
237 view_lookup
238 .entry((&src.view_schema, &src.view_name, &src.column_name))
239 .or_default()
240 .push(is_nullable);
241 }
242 }
243
244 for view in views.iter_mut() {
245 for col in view.columns.iter_mut() {
246 if let Some(nullable_flags) = view_lookup.get(&(
247 view.schema_name.as_str(),
248 view.name.as_str(),
249 col.name.as_str(),
250 )) {
251 if !nullable_flags.is_empty() && nullable_flags.iter().all(|&n| !n) {
253 col.is_nullable = false;
254 }
255 }
256 }
257 }
258}
259
260fn resolve_view_primary_keys(
261 views: &mut [TableInfo],
262 sources: &[ViewColumnSource],
263 tables: &[TableInfo],
264) {
265 let mut table_lookup: HashMap<(&str, &str, &str), bool> = HashMap::new();
267 for table in tables {
268 for col in &table.columns {
269 table_lookup.insert(
270 (&table.schema_name, &table.name, &col.name),
271 col.is_primary_key,
272 );
273 }
274 }
275
276 let mut view_lookup: HashMap<(&str, &str, &str), Vec<bool>> = HashMap::new();
278 for src in sources {
279 if let Some(&is_pk) =
280 table_lookup.get(&(src.table_schema.as_str(), src.table_name.as_str(), src.column_name.as_str()))
281 {
282 view_lookup
283 .entry((&src.view_schema, &src.view_name, &src.column_name))
284 .or_default()
285 .push(is_pk);
286 }
287 }
288
289 for view in views.iter_mut() {
290 for col in view.columns.iter_mut() {
291 if let Some(pk_flags) = view_lookup.get(&(
292 view.schema_name.as_str(),
293 view.name.as_str(),
294 col.name.as_str(),
295 )) {
296 if !pk_flags.is_empty() && pk_flags.iter().all(|&pk| pk) {
298 col.is_primary_key = true;
299 }
300 }
301 }
302 }
303}
304
305fn extract_enums(tables: &[TableInfo]) -> Vec<EnumInfo> {
309 let mut enums = Vec::new();
310
311 for table in tables {
312 for col in &table.columns {
313 if col.udt_name.starts_with("enum(") {
314 let variants = parse_enum_variants(&col.udt_name);
315 if !variants.is_empty() {
316 let enum_name = format!("{}_{}", table.name, col.name);
317 enums.push(EnumInfo {
318 schema_name: table.schema_name.clone(),
319 name: enum_name,
320 variants,
321 default_variant: None,
322 });
323 }
324 }
325 }
326 }
327
328 enums
329}
330
/// Extracts the labels of a MySQL enum `COLUMN_TYPE` such as `enum('a','b')`.
///
/// The list is scanned quote-aware, so labels may contain commas, spaces and
/// parentheses. Inside a quoted label, `''` stands for a single `'`. The
/// backslash escapes `\0`, `\n`, `\r`, `\t`, `\Z` and `\\` are decoded, and any
/// other `\x` yields `x`, following MySQL string-literal rules. Outside quotes,
/// whitespace is ignored, commas separate labels, and any other character is
/// taken verbatim.
///
/// Returns an empty list when `column_type` is not of the form `enum(...)`;
/// the prefix is matched case-sensitively. Empty labels are dropped.
fn parse_enum_variants(column_type: &str) -> Vec<String> {
    let list = match column_type
        .strip_prefix("enum(")
        .and_then(|rest| rest.strip_suffix(')'))
    {
        Some(list) => list,
        None => return Vec::new(),
    };

    let mut variants = Vec::new();
    let mut label = String::new();
    let mut chars = list.chars().peekable();

    while let Some(c) = chars.next() {
        match c {
            // A top-level comma ends the current label.
            ',' => {
                if !label.is_empty() {
                    variants.push(std::mem::take(&mut label));
                }
            }
            // Quoted text: read up to the matching closing quote.
            '\'' => loop {
                match chars.next() {
                    // A doubled quote is an escaped literal quote.
                    Some('\'') if chars.peek() == Some(&'\'') => {
                        chars.next();
                        label.push('\'');
                    }
                    Some('\'') | None => break,
                    Some('\\') => match chars.next() {
                        Some('0') => label.push('\0'),
                        Some('n') => label.push('\n'),
                        Some('r') => label.push('\r'),
                        Some('t') => label.push('\t'),
                        Some('Z') => label.push('\u{1a}'),
                        Some(other) => label.push(other),
                        None => break,
                    },
                    Some(other) => label.push(other),
                }
            },
            c if c.is_whitespace() => {}
            other => label.push(other),
        }
    }

    if !label.is_empty() {
        variants.push(label);
    }
    variants
}
345
#[cfg(test)]
mod tests {
    use super::*;

    // ---- builders ----------------------------------------------------------

    /// Builds a column with every varying field explicit; the helpers below fix
    /// the fields a given test does not care about.
    fn column(
        schema: &str,
        name: &str,
        udt_name: &str,
        is_nullable: bool,
        is_primary_key: bool,
        position: usize,
    ) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: "varchar".to_string(),
            udt_name: udt_name.to_string(),
            is_nullable,
            is_primary_key,
            ordinal_position: position as i32,
            schema_name: schema.to_string(),
            column_default: None,
        }
    }

    fn relation(schema: &str, name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// Table in `test_db`.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        relation("test_db", name, columns)
    }

    /// Non-nullable, non-key `test_db` column with the given `COLUMN_TYPE`.
    fn make_col(name: &str, udt_name: &str) -> ColumnInfo {
        column("test_db", name, udt_name, false, false, 0)
    }

    /// View whose columns all start out nullable and non-key.
    fn make_view(schema: &str, name: &str, columns: Vec<&str>) -> TableInfo {
        let columns = columns
            .into_iter()
            .enumerate()
            .map(|(i, col)| column(schema, col, "varchar(255)", true, false, i))
            .collect();
        relation(schema, name, columns)
    }

    /// Table whose non-key columns carry the given nullability.
    fn make_table_with_nullability(
        schema: &str,
        name: &str,
        columns: Vec<(&str, bool)>,
    ) -> TableInfo {
        let columns = columns
            .into_iter()
            .enumerate()
            .map(|(i, (col, nullable))| column(schema, col, "varchar(255)", nullable, false, i))
            .collect();
        relation(schema, name, columns)
    }

    /// Table whose non-nullable columns carry the given primary-key flag.
    fn make_table_with_pk(schema: &str, name: &str, columns: Vec<(&str, bool)>) -> TableInfo {
        let columns = columns
            .into_iter()
            .enumerate()
            .map(|(i, (col, is_pk))| column(schema, col, "varchar(255)", false, is_pk, i))
            .collect();
        relation(schema, name, columns)
    }

    fn make_source(
        view_schema: &str,
        view_name: &str,
        table_schema: &str,
        table_name: &str,
        column_name: &str,
    ) -> ViewColumnSource {
        ViewColumnSource {
            view_schema: view_schema.to_string(),
            view_name: view_name.to_string(),
            table_schema: table_schema.to_string(),
            table_name: table_name.to_string(),
            column_name: column_name.to_string(),
        }
    }

    // ---- parse_enum_variants -------------------------------------------------

    #[test]
    fn test_parse_simple() {
        assert_eq!(parse_enum_variants("enum('a','b','c')"), vec!["a", "b", "c"]);
    }

    #[test]
    fn test_parse_single_variant() {
        assert_eq!(parse_enum_variants("enum('only')"), vec!["only"]);
    }

    #[test]
    fn test_parse_with_spaces() {
        assert_eq!(parse_enum_variants("enum( 'a' , 'b' )"), vec!["a", "b"]);
    }

    #[test]
    fn test_parse_empty_parens() {
        assert!(parse_enum_variants("enum()").is_empty());
    }

    #[test]
    fn test_parse_varchar_not_enum() {
        assert!(parse_enum_variants("varchar(255)").is_empty());
    }

    #[test]
    fn test_parse_int_not_enum() {
        assert!(parse_enum_variants("int").is_empty());
    }

    #[test]
    fn test_parse_with_spaces_in_value() {
        assert_eq!(
            parse_enum_variants("enum('with space','no')"),
            vec!["with space", "no"]
        );
    }

    #[test]
    fn test_parse_empty_variant_filtered() {
        assert_eq!(parse_enum_variants("enum('a','','c')"), vec!["a", "c"]);
    }

    #[test]
    fn test_parse_uppercase_enum_not_matched() {
        assert!(parse_enum_variants("ENUM('a','b')").is_empty());
    }

    // ---- extract_enums -------------------------------------------------------

    #[test]
    fn test_extract_from_enum_column() {
        let tables = vec![make_table(
            "users",
            vec![make_col("status", "enum('active','inactive')")],
        )];
        let enums = extract_enums(&tables);
        assert_eq!(enums.len(), 1);
        assert_eq!(enums[0].variants, vec!["active", "inactive"]);
    }

    #[test]
    fn test_extract_enum_name_format() {
        let tables = vec![make_table("users", vec![make_col("status", "enum('a')")])];
        let enums = extract_enums(&tables);
        assert_eq!(enums[0].name, "users_status");
    }

    #[test]
    fn test_extract_no_enums() {
        let tables = vec![make_table(
            "users",
            vec![make_col("id", "int"), make_col("name", "varchar(255)")],
        )];
        assert!(extract_enums(&tables).is_empty());
    }

    #[test]
    fn test_extract_two_enum_columns_same_table() {
        let tables = vec![make_table(
            "users",
            vec![
                make_col("status", "enum('active','inactive')"),
                make_col("role", "enum('admin','user')"),
            ],
        )];
        let enums = extract_enums(&tables);
        assert_eq!(enums.len(), 2);
        assert_eq!(enums[0].name, "users_status");
        assert_eq!(enums[1].name, "users_role");
    }

    #[test]
    fn test_extract_enums_from_multiple_tables() {
        let tables = vec![
            make_table("users", vec![make_col("status", "enum('a')")]),
            make_table("posts", vec![make_col("state", "enum('b')")]),
        ];
        assert_eq!(extract_enums(&tables).len(), 2);
    }

    #[test]
    fn test_extract_non_enum_column_ignored() {
        let tables = vec![make_table(
            "users",
            vec![make_col("id", "int(11)"), make_col("status", "enum('a')")],
        )];
        assert_eq!(extract_enums(&tables).len(), 1);
    }

    // ---- resolve_view_nullability --------------------------------------------

    #[test]
    fn test_resolve_not_null_column() {
        let tables = vec![make_table_with_nullability(
            "db",
            "users",
            vec![("id", false), ("name", false)],
        )];
        let mut views = vec![make_view("db", "my_view", vec!["id", "name"])];
        let sources = vec![
            make_source("db", "my_view", "db", "users", "id"),
            make_source("db", "my_view", "db", "users", "name"),
        ];
        resolve_view_nullability(&mut views, &sources, &tables);
        assert!(!views[0].columns[0].is_nullable);
        assert!(!views[0].columns[1].is_nullable);
    }

    #[test]
    fn test_resolve_nullable_source() {
        let tables = vec![make_table_with_nullability(
            "db",
            "users",
            vec![("id", false), ("name", true)],
        )];
        let mut views = vec![make_view("db", "my_view", vec!["id", "name"])];
        let sources = vec![
            make_source("db", "my_view", "db", "users", "id"),
            make_source("db", "my_view", "db", "users", "name"),
        ];
        resolve_view_nullability(&mut views, &sources, &tables);
        assert!(!views[0].columns[0].is_nullable);
        assert!(views[0].columns[1].is_nullable);
    }

    #[test]
    fn test_resolve_no_match_stays_nullable() {
        let tables = vec![make_table_with_nullability("db", "users", vec![("id", false)])];
        let mut views = vec![make_view("db", "my_view", vec!["computed"])];
        // The view reads a NOT NULL column, just not under this column's name.
        let sources = vec![make_source("db", "my_view", "db", "users", "id")];
        resolve_view_nullability(&mut views, &sources, &tables);
        assert!(views[0].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_empty_sources() {
        let tables: Vec<TableInfo> = Vec::new();
        let mut views = vec![make_view("db", "my_view", vec!["id"])];
        resolve_view_nullability(&mut views, &[], &tables);
        assert!(views[0].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_any_nullable_source_keeps_column_nullable() {
        let tables = vec![
            make_table_with_nullability("db", "users", vec![("id", false)]),
            make_table_with_nullability("db", "posts", vec![("id", true)]),
        ];
        let mut views = vec![make_view("db", "my_view", vec!["id"])];
        let sources = vec![
            make_source("db", "my_view", "db", "users", "id"),
            make_source("db", "my_view", "db", "posts", "id"),
        ];
        resolve_view_nullability(&mut views, &sources, &tables);
        assert!(views[0].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_sources_apply_only_to_their_view() {
        let tables = vec![make_table_with_nullability("db", "users", vec![("id", false)])];
        let mut views = vec![
            make_view("db", "my_view", vec!["id"]),
            make_view("db", "other_view", vec!["id"]),
        ];
        let sources = vec![make_source("db", "my_view", "db", "users", "id")];
        resolve_view_nullability(&mut views, &sources, &tables);
        assert!(!views[0].columns[0].is_nullable);
        assert!(views[1].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_never_relaxes_not_null() {
        let tables = vec![make_table_with_nullability("db", "users", vec![("id", true)])];
        let mut views = vec![make_view("db", "my_view", vec!["id"])];
        views[0].columns[0].is_nullable = false;
        let sources = vec![make_source("db", "my_view", "db", "users", "id")];
        resolve_view_nullability(&mut views, &sources, &tables);
        assert!(!views[0].columns[0].is_nullable);
    }

    // ---- resolve_view_primary_keys -------------------------------------------

    #[test]
    fn test_resolve_pk_column() {
        let tables = vec![make_table_with_pk("db", "users", vec![("id", true), ("name", false)])];
        let mut views = vec![make_view("db", "my_view", vec!["id", "name"])];
        let sources = vec![
            make_source("db", "my_view", "db", "users", "id"),
            make_source("db", "my_view", "db", "users", "name"),
        ];
        resolve_view_primary_keys(&mut views, &sources, &tables);
        assert!(views[0].columns[0].is_primary_key);
        assert!(!views[0].columns[1].is_primary_key);
    }

    #[test]
    fn test_resolve_pk_no_sources() {
        let tables = vec![make_table_with_pk("db", "users", vec![("id", true)])];
        let mut views = vec![make_view("db", "my_view", vec!["id"])];
        resolve_view_primary_keys(&mut views, &[], &tables);
        assert!(!views[0].columns[0].is_primary_key);
    }

    #[test]
    fn test_resolve_pk_no_match() {
        let tables = vec![make_table_with_pk("db", "users", vec![("id", true)])];
        let mut views = vec![make_view("db", "my_view", vec!["computed"])];
        // The view reads a key column, just not under this column's name.
        let sources = vec![make_source("db", "my_view", "db", "users", "id")];
        resolve_view_primary_keys(&mut views, &sources, &tables);
        assert!(!views[0].columns[0].is_primary_key);
    }

    #[test]
    fn test_resolve_pk_requires_every_source_to_be_pk() {
        let tables = vec![
            make_table_with_pk("db", "users", vec![("id", true)]),
            make_table_with_pk("db", "audit", vec![("id", false)]),
        ];
        let mut views = vec![make_view("db", "my_view", vec!["id"])];
        let sources = vec![
            make_source("db", "my_view", "db", "users", "id"),
            make_source("db", "my_view", "db", "audit", "id"),
        ];
        resolve_view_primary_keys(&mut views, &sources, &tables);
        assert!(!views[0].columns[0].is_primary_key);
    }
}