Skip to main content

sqlx_gen/introspect/
sqlite.rs

1use std::collections::HashMap;
2
3use crate::error::Result;
4use sqlx::SqlitePool;
5
6use super::{ColumnInfo, SchemaInfo, TableInfo};
7
8pub async fn introspect(pool: &SqlitePool, include_views: bool) -> Result<SchemaInfo> {
9    let tables = fetch_tables(pool).await?;
10    let mut views = if include_views {
11        fetch_views(pool).await?
12    } else {
13        Vec::new()
14    };
15
16    if !views.is_empty() {
17        resolve_view_nullability(&mut views, &tables);
18        resolve_view_primary_keys(&mut views, &tables);
19    }
20
21    Ok(SchemaInfo {
22        tables,
23        views,
24        enums: Vec::new(),
25        composite_types: Vec::new(),
26        domains: Vec::new(),
27    })
28}
29
30async fn fetch_tables(pool: &SqlitePool) -> Result<Vec<TableInfo>> {
31    let table_names: Vec<(String,)> = sqlx::query_as(
32        "SELECT name FROM sqlite_master WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name",
33    )
34    .fetch_all(pool)
35    .await?;
36
37    let mut tables = Vec::new();
38
39    for (table_name,) in table_names {
40        let columns = fetch_columns(pool, &table_name).await?;
41        tables.push(TableInfo {
42            schema_name: "main".to_string(),
43            name: table_name,
44            columns,
45        });
46    }
47
48    Ok(tables)
49}
50
51async fn fetch_views(pool: &SqlitePool) -> Result<Vec<TableInfo>> {
52    let view_names: Vec<(String,)> = sqlx::query_as(
53        "SELECT name FROM sqlite_master WHERE type = 'view' ORDER BY name",
54    )
55    .fetch_all(pool)
56    .await?;
57
58    let mut views = Vec::new();
59
60    for (view_name,) in view_names {
61        let columns = fetch_columns(pool, &view_name).await?;
62        views.push(TableInfo {
63            schema_name: "main".to_string(),
64            name: view_name,
65            columns,
66        });
67    }
68
69    Ok(views)
70}
71
72async fn fetch_columns(pool: &SqlitePool, table_name: &str) -> Result<Vec<ColumnInfo>> {
73    // PRAGMA table_info returns: cid, name, type, notnull, dflt_value, pk
74    let pragma_query = format!("PRAGMA table_info(\"{}\")", table_name.replace('"', "\"\""));
75    let rows: Vec<(i32, String, String, bool, Option<String>, i32)> =
76        sqlx::query_as(&pragma_query).fetch_all(pool).await?;
77
78    Ok(rows
79        .into_iter()
80        .map(|(cid, name, declared_type, notnull, dflt_value, pk)| {
81            let upper = declared_type.to_uppercase();
82            ColumnInfo {
83                name,
84                data_type: upper.clone(),
85                udt_name: upper,
86                is_nullable: !notnull,
87                is_primary_key: pk > 0,
88                ordinal_position: cid,
89                schema_name: "main".to_string(),
90                column_default: dflt_value,
91            }
92        })
93        .collect())
94}
95
96/// Resolve view column nullability by matching column names against introspected tables.
97/// If a column name is found in exactly one table and is NOT NULL, propagate that.
98fn resolve_view_nullability(views: &mut [TableInfo], tables: &[TableInfo]) {
99    // Build lookup: column_name -> Vec<is_nullable>
100    let mut col_lookup: HashMap<&str, Vec<bool>> = HashMap::new();
101    for table in tables {
102        for col in &table.columns {
103            col_lookup.entry(&col.name).or_default().push(col.is_nullable);
104        }
105    }
106
107    for view in views.iter_mut() {
108        for col in view.columns.iter_mut() {
109            if let Some(nullable_flags) = col_lookup.get(col.name.as_str()) {
110                // Only resolve if column name appears in exactly one table
111                // and that column is NOT nullable
112                if nullable_flags.len() == 1 && !nullable_flags[0] {
113                    col.is_nullable = false;
114                }
115            }
116        }
117    }
118}
119
120/// Resolve view column primary keys by matching column names against introspected tables.
121/// If a column name is found in exactly one table and is a PK, propagate that.
122fn resolve_view_primary_keys(views: &mut [TableInfo], tables: &[TableInfo]) {
123    // Build lookup: column_name -> Vec<is_primary_key>
124    let mut col_lookup: HashMap<&str, Vec<bool>> = HashMap::new();
125    for table in tables {
126        for col in &table.columns {
127            col_lookup.entry(&col.name).or_default().push(col.is_primary_key);
128        }
129    }
130
131    for view in views.iter_mut() {
132        for col in view.columns.iter_mut() {
133            if let Some(pk_flags) = col_lookup.get(col.name.as_str()) {
134                // Only resolve if column name appears in exactly one table
135                // and that column is a PK
136                if pk_flags.len() == 1 && pk_flags[0] {
137                    col.is_primary_key = true;
138                }
139            }
140        }
141    }
142}
143
#[cfg(test)]
mod tests {
    use super::*;

    // ========== fixtures ==========

    /// Builds a `TEXT` column in the "main" schema with no default value.
    fn text_column(
        name: &str,
        position: usize,
        is_nullable: bool,
        is_primary_key: bool,
    ) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: "TEXT".to_string(),
            udt_name: "TEXT".to_string(),
            is_nullable,
            is_primary_key,
            ordinal_position: position as i32,
            schema_name: "main".to_string(),
            column_default: None,
        }
    }

    /// Wraps `columns` into a table or view living in the "main" schema.
    fn relation(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "main".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// Table whose columns are `(name, is_nullable)` pairs; none is a primary key.
    fn make_table(name: &str, columns: Vec<(&str, bool)>) -> TableInfo {
        let mut built = Vec::new();
        for (position, (col, nullable)) in columns.into_iter().enumerate() {
            built.push(text_column(col, position, nullable, false));
        }
        relation(name, built)
    }

    /// View whose columns all start out nullable and not part of a primary key.
    fn make_view(name: &str, columns: Vec<&str>) -> TableInfo {
        let mut built = Vec::new();
        for (position, col) in columns.into_iter().enumerate() {
            built.push(text_column(col, position, true, false));
        }
        relation(name, built)
    }

    /// Table whose columns are `(name, is_primary_key)` pairs; all are NOT NULL.
    fn make_table_with_pk(name: &str, columns: Vec<(&str, bool)>) -> TableInfo {
        let mut built = Vec::new();
        for (position, (col, is_pk)) in columns.into_iter().enumerate() {
            built.push(text_column(col, position, false, is_pk));
        }
        relation(name, built)
    }

    /// Nullability flag of every column of `view`, in column order.
    fn nullable_flags(view: &TableInfo) -> Vec<bool> {
        view.columns.iter().map(|c| c.is_nullable).collect()
    }

    /// Primary-key flag of every column of `view`, in column order.
    fn pk_flags(view: &TableInfo) -> Vec<bool> {
        view.columns.iter().map(|c| c.is_primary_key).collect()
    }

    // ========== resolve_view_nullability ==========

    #[test]
    fn test_resolve_unique_not_null() {
        let tables = vec![make_table("users", vec![("id", false), ("name", false)])];
        let mut views = vec![make_view("my_view", vec!["id", "name"])];

        resolve_view_nullability(&mut views, &tables);

        assert_eq!(nullable_flags(&views[0]), [false, false]);
    }

    #[test]
    fn test_resolve_nullable_source() {
        let tables = vec![make_table("users", vec![("id", false), ("name", true)])];
        let mut views = vec![make_view("my_view", vec!["id", "name"])];

        resolve_view_nullability(&mut views, &tables);

        assert_eq!(nullable_flags(&views[0]), [false, true]);
    }

    #[test]
    fn test_resolve_ambiguous_stays_nullable() {
        // Two tables expose an "id" column, so the source cannot be determined.
        let users = make_table("users", vec![("id", false)]);
        let orders = make_table("orders", vec![("id", false)]);
        let tables = vec![users, orders];
        let mut views = vec![make_view("my_view", vec!["id"])];

        resolve_view_nullability(&mut views, &tables);

        assert_eq!(nullable_flags(&views[0]), [true]);
    }

    #[test]
    fn test_resolve_no_match() {
        let tables = vec![make_table("users", vec![("id", false)])];
        let mut views = vec![make_view("my_view", vec!["computed"])];

        resolve_view_nullability(&mut views, &tables);

        assert_eq!(nullable_flags(&views[0]), [true]);
    }

    #[test]
    fn test_resolve_empty_tables() {
        let mut views = vec![make_view("my_view", vec!["id"])];

        resolve_view_nullability(&mut views, &[]);

        assert_eq!(nullable_flags(&views[0]), [true]);
    }

    // ========== resolve_view_primary_keys ==========

    #[test]
    fn test_resolve_pk_unique_match() {
        let tables = vec![make_table_with_pk(
            "users",
            vec![("id", true), ("name", false)],
        )];
        let mut views = vec![make_view("my_view", vec!["id", "name"])];

        resolve_view_primary_keys(&mut views, &tables);

        assert_eq!(pk_flags(&views[0]), [true, false]);
    }

    #[test]
    fn test_resolve_pk_ambiguous() {
        // Two tables expose an "id" key, so the view column must not be marked.
        let users = make_table_with_pk("users", vec![("id", true)]);
        let orders = make_table_with_pk("orders", vec![("id", true)]);
        let tables = vec![users, orders];
        let mut views = vec![make_view("my_view", vec!["id"])];

        resolve_view_primary_keys(&mut views, &tables);

        assert_eq!(pk_flags(&views[0]), [false]);
    }

    #[test]
    fn test_resolve_pk_no_match() {
        let tables = vec![make_table_with_pk("users", vec![("id", true)])];
        let mut views = vec![make_view("my_view", vec!["computed"])];

        resolve_view_primary_keys(&mut views, &tables);

        assert_eq!(pk_flags(&views[0]), [false]);
    }

    #[test]
    fn test_resolve_pk_empty_tables() {
        let mut views = vec![make_view("my_view", vec!["id"])];

        resolve_view_primary_keys(&mut views, &[]);

        assert_eq!(pk_flags(&views[0]), [false]);
    }
}