Skip to main content

sqlx_gen/introspect/
sqlite.rs

1use std::collections::HashMap;
2
3use crate::error::Result;
4use sqlx::SqlitePool;
5
6use super::{ColumnInfo, SchemaInfo, TableInfo};
7
8pub async fn introspect(pool: &SqlitePool, include_views: bool) -> Result<SchemaInfo> {
9    let tables = fetch_tables(pool).await?;
10    let mut views = if include_views {
11        fetch_views(pool).await?
12    } else {
13        Vec::new()
14    };
15
16    if !views.is_empty() {
17        resolve_view_nullability(&mut views, &tables);
18    }
19
20    Ok(SchemaInfo {
21        tables,
22        views,
23        enums: Vec::new(),
24        composite_types: Vec::new(),
25        domains: Vec::new(),
26    })
27}
28
29async fn fetch_tables(pool: &SqlitePool) -> Result<Vec<TableInfo>> {
30    let table_names: Vec<(String,)> = sqlx::query_as(
31        "SELECT name FROM sqlite_master WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name",
32    )
33    .fetch_all(pool)
34    .await?;
35
36    let mut tables = Vec::new();
37
38    for (table_name,) in table_names {
39        let columns = fetch_columns(pool, &table_name).await?;
40        tables.push(TableInfo {
41            schema_name: "main".to_string(),
42            name: table_name,
43            columns,
44        });
45    }
46
47    Ok(tables)
48}
49
50async fn fetch_views(pool: &SqlitePool) -> Result<Vec<TableInfo>> {
51    let view_names: Vec<(String,)> = sqlx::query_as(
52        "SELECT name FROM sqlite_master WHERE type = 'view' ORDER BY name",
53    )
54    .fetch_all(pool)
55    .await?;
56
57    let mut views = Vec::new();
58
59    for (view_name,) in view_names {
60        let columns = fetch_columns(pool, &view_name).await?;
61        views.push(TableInfo {
62            schema_name: "main".to_string(),
63            name: view_name,
64            columns,
65        });
66    }
67
68    Ok(views)
69}
70
71async fn fetch_columns(pool: &SqlitePool, table_name: &str) -> Result<Vec<ColumnInfo>> {
72    // PRAGMA table_info returns: cid, name, type, notnull, dflt_value, pk
73    let pragma_query = format!("PRAGMA table_info(\"{}\")", table_name.replace('"', "\"\""));
74    let rows: Vec<(i32, String, String, bool, Option<String>, i32)> =
75        sqlx::query_as(&pragma_query).fetch_all(pool).await?;
76
77    Ok(rows
78        .into_iter()
79        .map(|(cid, name, declared_type, notnull, dflt_value, pk)| {
80            let upper = declared_type.to_uppercase();
81            ColumnInfo {
82                name,
83                data_type: upper.clone(),
84                udt_name: upper,
85                is_nullable: !notnull,
86                is_primary_key: pk > 0,
87                ordinal_position: cid,
88                schema_name: "main".to_string(),
89                column_default: dflt_value,
90            }
91        })
92        .collect())
93}
94
95/// Resolve view column nullability by matching column names against introspected tables.
96/// If a column name is found in exactly one table and is NOT NULL, propagate that.
97fn resolve_view_nullability(views: &mut [TableInfo], tables: &[TableInfo]) {
98    // Build lookup: column_name -> Vec<is_nullable>
99    let mut col_lookup: HashMap<&str, Vec<bool>> = HashMap::new();
100    for table in tables {
101        for col in &table.columns {
102            col_lookup.entry(&col.name).or_default().push(col.is_nullable);
103        }
104    }
105
106    for view in views.iter_mut() {
107        for col in view.columns.iter_mut() {
108            if let Some(nullable_flags) = col_lookup.get(col.name.as_str()) {
109                // Only resolve if column name appears in exactly one table
110                // and that column is NOT nullable
111                if nullable_flags.len() == 1 && !nullable_flags[0] {
112                    col.is_nullable = false;
113                }
114            }
115        }
116    }
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `TEXT` column at zero-based `position` with the given nullability.
    fn make_column(position: usize, name: &str, is_nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: "TEXT".to_string(),
            udt_name: "TEXT".to_string(),
            is_nullable,
            is_primary_key: false,
            ordinal_position: position as i32,
            schema_name: "main".to_string(),
            column_default: None,
        }
    }

    /// Build a table whose columns are given as `(name, is_nullable)` pairs.
    fn make_table(name: &str, columns: Vec<(&str, bool)>) -> TableInfo {
        TableInfo {
            schema_name: "main".to_string(),
            name: name.to_string(),
            columns: columns
                .into_iter()
                .enumerate()
                .map(|(i, (col, nullable))| make_column(i, col, nullable))
                .collect(),
        }
    }

    /// Build a view whose columns are all nullable.
    fn make_view(name: &str, columns: Vec<&str>) -> TableInfo {
        make_table(name, columns.into_iter().map(|col| (col, true)).collect())
    }

    #[test]
    fn test_resolve_unique_not_null() {
        let tables = vec![make_table("users", vec![("id", false), ("name", false)])];
        let mut views = vec![make_view("my_view", vec!["id", "name"])];
        resolve_view_nullability(&mut views, &tables);
        assert!(!views[0].columns[0].is_nullable);
        assert!(!views[0].columns[1].is_nullable);
    }

    #[test]
    fn test_resolve_nullable_source() {
        let tables = vec![make_table("users", vec![("id", false), ("name", true)])];
        let mut views = vec![make_view("my_view", vec!["id", "name"])];
        resolve_view_nullability(&mut views, &tables);
        assert!(!views[0].columns[0].is_nullable);
        assert!(views[0].columns[1].is_nullable);
    }

    #[test]
    fn test_resolve_ambiguous_stays_nullable() {
        // "id" appears in two tables — ambiguous, stay nullable
        let tables = vec![
            make_table("users", vec![("id", false)]),
            make_table("orders", vec![("id", false)]),
        ];
        let mut views = vec![make_view("my_view", vec!["id"])];
        resolve_view_nullability(&mut views, &tables);
        assert!(views[0].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_ambiguous_mixed_nullability_stays_nullable() {
        // "email" is NOT NULL in one table but nullable in another — still ambiguous.
        let tables = vec![
            make_table("users", vec![("email", false)]),
            make_table("contacts", vec![("email", true)]),
        ];
        let mut views = vec![make_view("my_view", vec!["email"])];
        resolve_view_nullability(&mut views, &tables);
        assert!(views[0].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_no_match() {
        let tables = vec![make_table("users", vec![("id", false)])];
        let mut views = vec![make_view("my_view", vec!["computed"])];
        resolve_view_nullability(&mut views, &tables);
        assert!(views[0].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_empty_tables() {
        let mut views = vec![make_view("my_view", vec!["id"])];
        resolve_view_nullability(&mut views, &[]);
        assert!(views[0].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_never_makes_column_nullable() {
        // Resolution may only clear `is_nullable`; a nullable source must not flip it back.
        let tables = vec![make_table("users", vec![("name", true)])];
        let mut views = vec![make_view("my_view", vec!["name"])];
        views[0].columns[0].is_nullable = false;
        resolve_view_nullability(&mut views, &tables);
        assert!(!views[0].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_applies_to_every_view() {
        let tables = vec![make_table("users", vec![("id", false)])];
        let mut views = vec![
            make_view("v1", vec!["id"]),
            make_view("v2", vec!["other", "id"]),
        ];
        resolve_view_nullability(&mut views, &tables);
        assert!(!views[0].columns[0].is_nullable);
        assert!(views[1].columns[0].is_nullable);
        assert!(!views[1].columns[1].is_nullable);
    }
}