1use std::collections::HashMap;
2
3use crate::error::Result;
4use sqlx::SqlitePool;
5
6use super::{ColumnInfo, EnumInfo, SchemaInfo, TableInfo};
7
8pub async fn introspect(pool: &SqlitePool, include_views: bool) -> Result<SchemaInfo> {
9 let mut tables = fetch_tables(pool).await?;
10 let mut views = if include_views {
11 fetch_views(pool).await?
12 } else {
13 Vec::new()
14 };
15
16 if !views.is_empty() {
17 resolve_view_nullability(&mut views, &tables);
18 resolve_view_primary_keys(&mut views, &tables);
19 }
20
21 let enums = extract_check_enums(pool, &mut tables).await?;
22
23 Ok(SchemaInfo {
24 tables,
25 views,
26 enums,
27 composite_types: Vec::new(),
28 domains: Vec::new(),
29 })
30}
31
32async fn extract_check_enums(pool: &SqlitePool, tables: &mut [TableInfo]) -> Result<Vec<EnumInfo>> {
38 let mut enums = Vec::new();
39
40 for table in tables.iter_mut() {
41 let sql: Option<(Option<String>,)> =
42 sqlx::query_as("SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?")
43 .bind(&table.name)
44 .fetch_optional(pool)
45 .await?;
46 let Some((Some(ddl),)) = sql else { continue };
47
48 for col in table.columns.iter_mut() {
49 if let Some(variants) = parse_check_in_variants(&ddl, &col.name) {
50 if variants.is_empty() {
51 continue;
52 }
53 let enum_name = format!("{}_{}_enum", table.name, col.name);
54 col.udt_name = enum_name.clone();
55 enums.push(EnumInfo {
56 schema_name: "main".to_string(),
57 name: enum_name,
58 variants,
59 default_variant: None,
60 });
61 }
62 }
63 }
64
65 Ok(enums)
66}
67
/// Extracts the string variants of a `CHECK (<column> IN ('a', 'b', ...))`
/// constraint on `column` from a `CREATE TABLE` statement.
///
/// The DDL is tokenized first, so string literals, quoted identifiers
/// (`"x"`, `` `x` ``, `[x]`) and comments are never mistaken for keywords or
/// column names. Only the bodies of `CHECK (...)` clauses are searched.
///
/// Inside a body, `column` must appear as a whole identifier, compared
/// ASCII-case-insensitively. It must be directly followed by `IN` and a
/// parenthesised list made up solely of string literals. `''` inside a
/// literal is unescaped to `'`. `<column> NOT IN (...)` and
/// `NOT <column> IN (...)` are ignored, because they exclude values rather
/// than enumerate them.
///
/// Returns the variants of the first matching constraint, or `None` when no
/// CHECK clause enumerates string values for `column`.
fn parse_check_in_variants(ddl: &str, column: &str) -> Option<Vec<String>> {
    let tokens = tokenize_sql(ddl);
    let mut pos = 0;

    while pos < tokens.len() {
        let opens_check = is_keyword(&tokens[pos], "check")
            && tokens.get(pos + 1) == Some(&SqlToken::Punct('('));
        if !opens_check {
            pos += 1;
            continue;
        }

        // `close` is the index of the matching `)` (or the end of input when
        // the DDL is unbalanced); the body sits strictly between the parens.
        let close = matching_paren(&tokens, pos + 1);
        let body = &tokens[pos + 2..close];
        if let Some(variants) = find_in_list_for_column(body, column) {
            return Some(variants);
        }
        pos = close + 1;
    }

    None
}

/// A lexical token of SQLite DDL, just detailed enough for CHECK parsing.
#[derive(Debug, PartialEq)]
enum SqlToken {
    /// Unquoted keyword, identifier or numeric fragment, verbatim.
    Word(String),
    /// Contents of a `"…"`, `` `…` `` or `[…]` quoted identifier.
    Quoted(String),
    /// Contents of a `'…'` string literal with `''` collapsed to `'`.
    Str(String),
    /// Any other non-whitespace character (parentheses, commas, operators).
    Punct(char),
}

/// Splits SQL text into [`SqlToken`]s, dropping whitespace and comments.
fn tokenize_sql(sql: &str) -> Vec<SqlToken> {
    let mut tokens = Vec::new();
    let mut chars = sql.chars().peekable();

    while let Some(c) = chars.next() {
        match c {
            _ if c.is_whitespace() => {}
            '-' if chars.peek() == Some(&'-') => {
                // `--` comment: runs to the end of the line.
                for skipped in chars.by_ref() {
                    if skipped == '\n' {
                        break;
                    }
                }
            }
            '/' if chars.peek() == Some(&'*') => {
                // `/* ... */` comment; an unterminated one swallows the rest.
                chars.next();
                let mut prev = '\0';
                for skipped in chars.by_ref() {
                    if prev == '*' && skipped == '/' {
                        break;
                    }
                    prev = skipped;
                }
            }
            '\'' => tokens.push(SqlToken::Str(read_delimited(&mut chars, '\''))),
            '"' | '`' => tokens.push(SqlToken::Quoted(read_delimited(&mut chars, c))),
            '[' => {
                let name: String = chars.by_ref().take_while(|&ch| ch != ']').collect();
                tokens.push(SqlToken::Quoted(name));
            }
            _ if is_word_char(c) => {
                let mut word = String::from(c);
                while let Some(&next) = chars.peek() {
                    if !is_word_char(next) {
                        break;
                    }
                    word.push(next);
                    chars.next();
                }
                tokens.push(SqlToken::Word(word));
            }
            _ => tokens.push(SqlToken::Punct(c)),
        }
    }

    tokens
}

/// Reads a quoted run whose opening `quote` was already consumed.
///
/// A doubled `quote` stands for one literal `quote` (SQL escaping). An
/// unterminated run yields everything up to the end of input.
fn read_delimited(chars: &mut std::iter::Peekable<std::str::Chars<'_>>, quote: char) -> String {
    let mut text = String::new();
    while let Some(c) = chars.next() {
        if c != quote {
            text.push(c);
        } else if chars.peek() == Some(&quote) {
            text.push(quote);
            chars.next();
        } else {
            break;
        }
    }
    text
}

/// Characters that may appear in an unquoted SQLite identifier or number.
fn is_word_char(c: char) -> bool {
    c.is_ascii_alphanumeric() || c == '_' || c == '$' || !c.is_ascii()
}

/// True when `token` is the unquoted keyword `keyword` (ASCII case-insensitive).
fn is_keyword(token: &SqlToken, keyword: &str) -> bool {
    matches!(token, SqlToken::Word(word) if word.eq_ignore_ascii_case(keyword))
}

/// Returns the index of the `)` matching the `(` at `open`, or
/// `tokens.len()` when the parentheses are unbalanced.
///
/// `tokens[open]` must be `Punct('(')`.
fn matching_paren(tokens: &[SqlToken], open: usize) -> usize {
    let mut depth = 0usize;
    for (idx, token) in tokens.iter().enumerate().skip(open) {
        match token {
            SqlToken::Punct('(') => depth += 1,
            SqlToken::Punct(')') => {
                depth -= 1;
                if depth == 0 {
                    return idx;
                }
            }
            _ => {}
        }
    }
    tokens.len()
}

/// Searches one CHECK body for `<column> IN (<string literals>)`.
fn find_in_list_for_column(body: &[SqlToken], column: &str) -> Option<Vec<String>> {
    (0..body.len()).find_map(|i| {
        let names_column = match &body[i] {
            SqlToken::Word(name) | SqlToken::Quoted(name) => name.eq_ignore_ascii_case(column),
            _ => false,
        };
        // `NOT col IN (...)` excludes the listed values, so it is not an enum.
        let negated = i > 0 && is_keyword(&body[i - 1], "not");
        if !names_column || negated || !is_keyword(body.get(i + 1)?, "in") {
            return None;
        }
        parse_string_list(&body[i + 2..])
    })
}

/// Parses `( 'a' , 'b' , ... )` from the start of `tokens`.
///
/// Returns `None` unless the list is non-empty and every element is a string
/// literal.
fn parse_string_list(tokens: &[SqlToken]) -> Option<Vec<String>> {
    let mut iter = tokens.iter();
    if iter.next()? != &SqlToken::Punct('(') {
        return None;
    }

    let mut variants = Vec::new();
    loop {
        match iter.next()? {
            SqlToken::Str(value) => variants.push(value.clone()),
            _ => return None,
        }
        match iter.next()? {
            SqlToken::Punct(',') => {}
            SqlToken::Punct(')') => return Some(variants),
            _ => return None,
        }
    }
}
138
139async fn fetch_tables(pool: &SqlitePool) -> Result<Vec<TableInfo>> {
140 let table_names: Vec<(String,)> = sqlx::query_as(
141 "SELECT name FROM sqlite_master WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name",
142 )
143 .fetch_all(pool)
144 .await?;
145
146 let mut tables = Vec::new();
147
148 for (table_name,) in table_names {
149 let columns = fetch_columns(pool, &table_name).await?;
150 tables.push(TableInfo {
151 schema_name: "main".to_string(),
152 name: table_name,
153 columns,
154 });
155 }
156
157 Ok(tables)
158}
159
160async fn fetch_views(pool: &SqlitePool) -> Result<Vec<TableInfo>> {
161 let view_names: Vec<(String,)> =
162 sqlx::query_as("SELECT name FROM sqlite_master WHERE type = 'view' ORDER BY name")
163 .fetch_all(pool)
164 .await?;
165
166 let mut views = Vec::new();
167
168 for (view_name,) in view_names {
169 let columns = fetch_columns(pool, &view_name).await?;
170 views.push(TableInfo {
171 schema_name: "main".to_string(),
172 name: view_name,
173 columns,
174 });
175 }
176
177 Ok(views)
178}
179
180async fn fetch_columns(pool: &SqlitePool, table_name: &str) -> Result<Vec<ColumnInfo>> {
181 let pragma_query = format!("PRAGMA table_info(\"{}\")", table_name.replace('"', "\"\""));
183 let rows: Vec<(i32, String, String, bool, Option<String>, i32)> =
184 sqlx::query_as(&pragma_query).fetch_all(pool).await?;
185
186 Ok(rows
187 .into_iter()
188 .map(|(cid, name, declared_type, notnull, dflt_value, pk)| {
189 let upper = declared_type.to_uppercase();
190 ColumnInfo {
191 name,
192 data_type: upper.clone(),
193 udt_name: upper,
194 udt_schema: None,
195 is_nullable: !notnull,
196 is_primary_key: pk > 0,
197 ordinal_position: cid,
198 schema_name: "main".to_string(),
199 column_default: dflt_value,
200 }
201 })
202 .collect())
203}
204
205fn resolve_view_nullability(views: &mut [TableInfo], tables: &[TableInfo]) {
208 let mut col_lookup: HashMap<&str, Vec<bool>> = HashMap::new();
210 for table in tables {
211 for col in &table.columns {
212 col_lookup
213 .entry(&col.name)
214 .or_default()
215 .push(col.is_nullable);
216 }
217 }
218
219 for view in views.iter_mut() {
220 for col in view.columns.iter_mut() {
221 if let Some(nullable_flags) = col_lookup.get(col.name.as_str()) {
222 if nullable_flags.len() == 1 && !nullable_flags[0] {
225 col.is_nullable = false;
226 }
227 }
228 }
229 }
230}
231
232fn resolve_view_primary_keys(views: &mut [TableInfo], tables: &[TableInfo]) {
235 let mut col_lookup: HashMap<&str, Vec<bool>> = HashMap::new();
237 for table in tables {
238 for col in &table.columns {
239 col_lookup
240 .entry(&col.name)
241 .or_default()
242 .push(col.is_primary_key);
243 }
244 }
245
246 for view in views.iter_mut() {
247 for col in view.columns.iter_mut() {
248 if let Some(pk_flags) = col_lookup.get(col.name.as_str()) {
249 if pk_flags.len() == 1 && pk_flags[0] {
252 col.is_primary_key = true;
253 }
254 }
255 }
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TEXT` column in schema `main` at the given zero-based position.
    fn text_column(position: usize, name: &str, is_nullable: bool, is_primary_key: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: "TEXT".to_string(),
            udt_name: "TEXT".to_string(),
            is_nullable,
            is_primary_key,
            ordinal_position: position as i32,
            schema_name: "main".to_string(),
            udt_schema: None,
            column_default: None,
        }
    }

    /// Wraps columns into a relation (table or view) in schema `main`.
    fn relation(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "main".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// Table whose `(name, nullable)` columns are never primary keys.
    fn make_table(name: &str, columns: Vec<(&str, bool)>) -> TableInfo {
        let cols = columns
            .into_iter()
            .enumerate()
            .map(|(pos, (col, nullable))| text_column(pos, col, nullable, false))
            .collect();
        relation(name, cols)
    }

    /// View whose columns start out nullable and not primary keys.
    fn make_view(name: &str, columns: Vec<&str>) -> TableInfo {
        let cols = columns
            .into_iter()
            .enumerate()
            .map(|(pos, col)| text_column(pos, col, true, false))
            .collect();
        relation(name, cols)
    }

    /// Table whose `(name, is_pk)` columns are all NOT NULL.
    fn make_table_with_pk(name: &str, columns: Vec<(&str, bool)>) -> TableInfo {
        let cols = columns
            .into_iter()
            .enumerate()
            .map(|(pos, (col, is_pk))| text_column(pos, col, false, is_pk))
            .collect();
        relation(name, cols)
    }

    // ---- nullability --------------------------------------------------------

    #[test]
    fn test_resolve_unique_not_null() {
        let tables = vec![make_table("users", vec![("id", false), ("name", false)])];
        let mut views = vec![make_view("my_view", vec!["id", "name"])];
        resolve_view_nullability(&mut views, &tables);
        assert!(!views[0].columns[0].is_nullable);
        assert!(!views[0].columns[1].is_nullable);
    }

    #[test]
    fn test_resolve_nullable_source() {
        let tables = vec![make_table("users", vec![("id", false), ("name", true)])];
        let mut views = vec![make_view("my_view", vec!["id", "name"])];
        resolve_view_nullability(&mut views, &tables);
        assert!(!views[0].columns[0].is_nullable);
        assert!(views[0].columns[1].is_nullable);
    }

    #[test]
    fn test_resolve_ambiguous_stays_nullable() {
        // Same column name in two tables: the source is ambiguous.
        let tables = vec![
            make_table("users", vec![("id", false)]),
            make_table("orders", vec![("id", false)]),
        ];
        let mut views = vec![make_view("my_view", vec!["id"])];
        resolve_view_nullability(&mut views, &tables);
        assert!(views[0].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_no_match() {
        let tables = vec![make_table("users", vec![("id", false)])];
        let mut views = vec![make_view("my_view", vec!["computed"])];
        resolve_view_nullability(&mut views, &tables);
        assert!(views[0].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_empty_tables() {
        let mut views = vec![make_view("my_view", vec!["id"])];
        resolve_view_nullability(&mut views, &[]);
        assert!(views[0].columns[0].is_nullable);
    }

    // ---- primary keys -------------------------------------------------------

    #[test]
    fn test_resolve_pk_unique_match() {
        let tables = vec![make_table_with_pk(
            "users",
            vec![("id", true), ("name", false)],
        )];
        let mut views = vec![make_view("my_view", vec!["id", "name"])];
        resolve_view_primary_keys(&mut views, &tables);
        assert!(views[0].columns[0].is_primary_key);
        assert!(!views[0].columns[1].is_primary_key);
    }

    #[test]
    fn test_resolve_pk_ambiguous() {
        // Same primary-key column name in two tables: the source is ambiguous.
        let tables = vec![
            make_table_with_pk("users", vec![("id", true)]),
            make_table_with_pk("orders", vec![("id", true)]),
        ];
        let mut views = vec![make_view("my_view", vec!["id"])];
        resolve_view_primary_keys(&mut views, &tables);
        assert!(!views[0].columns[0].is_primary_key);
    }

    #[test]
    fn test_resolve_pk_no_match() {
        let tables = vec![make_table_with_pk("users", vec![("id", true)])];
        let mut views = vec![make_view("my_view", vec!["computed"])];
        resolve_view_primary_keys(&mut views, &tables);
        assert!(!views[0].columns[0].is_primary_key);
    }

    #[test]
    fn test_resolve_pk_empty_tables() {
        let mut views = vec![make_view("my_view", vec!["id"])];
        resolve_view_primary_keys(&mut views, &[]);
        assert!(!views[0].columns[0].is_primary_key);
    }

    // ---- CHECK (... IN (...)) parsing ---------------------------------------

    #[test]
    fn test_parse_check_in_simple() {
        let ddl = "CREATE TABLE t (id INTEGER PRIMARY KEY, status TEXT CHECK (status IN ('active', 'inactive')) NOT NULL)";
        assert_eq!(
            parse_check_in_variants(ddl, "status"),
            Some(vec!["active".to_string(), "inactive".to_string()])
        );
    }

    #[test]
    fn test_parse_check_in_three_variants() {
        let ddl = "CREATE TABLE t (priority TEXT CHECK (priority IN ('low','medium','high')))";
        let expected: Vec<String> = ["low", "medium", "high"].iter().map(|s| s.to_string()).collect();
        assert_eq!(parse_check_in_variants(ddl, "priority"), Some(expected));
    }

    #[test]
    fn test_parse_check_in_returns_none_for_other_column() {
        let ddl = "CREATE TABLE t (status TEXT CHECK (status IN ('a','b')))";
        assert_eq!(parse_check_in_variants(ddl, "other"), None);
    }

    #[test]
    fn test_parse_check_in_returns_none_without_check() {
        let ddl = "CREATE TABLE t (status TEXT)";
        assert_eq!(parse_check_in_variants(ddl, "status"), None);
    }

    #[test]
    fn test_parse_check_in_case_insensitive_keyword() {
        let ddl = "CREATE TABLE t (status TEXT check (Status in ('a','b')))";
        assert_eq!(
            parse_check_in_variants(ddl, "status"),
            Some(vec!["a".to_string(), "b".to_string()])
        );
    }
}
459}