1use std::collections::HashMap;
2
3use crate::error::Result;
4use sqlx::SqlitePool;
5
6use super::{ColumnInfo, SchemaInfo, TableInfo};
7
8pub async fn introspect(pool: &SqlitePool, include_views: bool) -> Result<SchemaInfo> {
9 let tables = fetch_tables(pool).await?;
10 let mut views = if include_views {
11 fetch_views(pool).await?
12 } else {
13 Vec::new()
14 };
15
16 if !views.is_empty() {
17 resolve_view_nullability(&mut views, &tables);
18 resolve_view_primary_keys(&mut views, &tables);
19 }
20
21 Ok(SchemaInfo {
22 tables,
23 views,
24 enums: Vec::new(),
25 composite_types: Vec::new(),
26 domains: Vec::new(),
27 })
28}
29
30async fn fetch_tables(pool: &SqlitePool) -> Result<Vec<TableInfo>> {
31 let table_names: Vec<(String,)> = sqlx::query_as(
32 "SELECT name FROM sqlite_master WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name",
33 )
34 .fetch_all(pool)
35 .await?;
36
37 let mut tables = Vec::new();
38
39 for (table_name,) in table_names {
40 let columns = fetch_columns(pool, &table_name).await?;
41 tables.push(TableInfo {
42 schema_name: "main".to_string(),
43 name: table_name,
44 columns,
45 });
46 }
47
48 Ok(tables)
49}
50
51async fn fetch_views(pool: &SqlitePool) -> Result<Vec<TableInfo>> {
52 let view_names: Vec<(String,)> = sqlx::query_as(
53 "SELECT name FROM sqlite_master WHERE type = 'view' ORDER BY name",
54 )
55 .fetch_all(pool)
56 .await?;
57
58 let mut views = Vec::new();
59
60 for (view_name,) in view_names {
61 let columns = fetch_columns(pool, &view_name).await?;
62 views.push(TableInfo {
63 schema_name: "main".to_string(),
64 name: view_name,
65 columns,
66 });
67 }
68
69 Ok(views)
70}
71
72async fn fetch_columns(pool: &SqlitePool, table_name: &str) -> Result<Vec<ColumnInfo>> {
73 let pragma_query = format!("PRAGMA table_info(\"{}\")", table_name.replace('"', "\"\""));
75 let rows: Vec<(i32, String, String, bool, Option<String>, i32)> =
76 sqlx::query_as(&pragma_query).fetch_all(pool).await?;
77
78 Ok(rows
79 .into_iter()
80 .map(|(cid, name, declared_type, notnull, dflt_value, pk)| {
81 let upper = declared_type.to_uppercase();
82 ColumnInfo {
83 name,
84 data_type: upper.clone(),
85 udt_name: upper,
86 is_nullable: !notnull,
87 is_primary_key: pk > 0,
88 ordinal_position: cid,
89 schema_name: "main".to_string(),
90 column_default: dflt_value,
91 }
92 })
93 .collect())
94}
95
96fn resolve_view_nullability(views: &mut [TableInfo], tables: &[TableInfo]) {
99 let mut col_lookup: HashMap<&str, Vec<bool>> = HashMap::new();
101 for table in tables {
102 for col in &table.columns {
103 col_lookup.entry(&col.name).or_default().push(col.is_nullable);
104 }
105 }
106
107 for view in views.iter_mut() {
108 for col in view.columns.iter_mut() {
109 if let Some(nullable_flags) = col_lookup.get(col.name.as_str()) {
110 if nullable_flags.len() == 1 && !nullable_flags[0] {
113 col.is_nullable = false;
114 }
115 }
116 }
117 }
118}
119
120fn resolve_view_primary_keys(views: &mut [TableInfo], tables: &[TableInfo]) {
123 let mut col_lookup: HashMap<&str, Vec<bool>> = HashMap::new();
125 for table in tables {
126 for col in &table.columns {
127 col_lookup.entry(&col.name).or_default().push(col.is_primary_key);
128 }
129 }
130
131 for view in views.iter_mut() {
132 for col in view.columns.iter_mut() {
133 if let Some(pk_flags) = col_lookup.get(col.name.as_str()) {
134 if pk_flags.len() == 1 && pk_flags[0] {
137 col.is_primary_key = true;
138 }
139 }
140 }
141 }
142}
143
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `TEXT` column in schema `main` at 0-based `position`.
    fn text_column(
        position: usize,
        name: &str,
        is_nullable: bool,
        is_primary_key: bool,
    ) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: "TEXT".to_string(),
            udt_name: "TEXT".to_string(),
            is_nullable,
            is_primary_key,
            ordinal_position: position as i32,
            schema_name: "main".to_string(),
            column_default: None,
        }
    }

    /// Wraps `columns` in a table/view named `name` in schema `main`.
    fn relation(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "main".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// Base table from `(column, is_nullable)` pairs; no column is a primary key.
    fn make_table(name: &str, columns: Vec<(&str, bool)>) -> TableInfo {
        let built = columns
            .into_iter()
            .enumerate()
            .map(|(position, (column, nullable))| text_column(position, column, nullable, false))
            .collect();
        relation(name, built)
    }

    /// View whose columns all start nullable and non-key, i.e. before any resolution.
    fn make_view(name: &str, columns: Vec<&str>) -> TableInfo {
        let built = columns
            .into_iter()
            .enumerate()
            .map(|(position, column)| text_column(position, column, true, false))
            .collect();
        relation(name, built)
    }

    /// Base table from `(column, is_primary_key)` pairs; every column is NOT NULL.
    fn make_table_with_pk(name: &str, columns: Vec<(&str, bool)>) -> TableInfo {
        let built = columns
            .into_iter()
            .enumerate()
            .map(|(position, (column, is_pk))| text_column(position, column, false, is_pk))
            .collect();
        relation(name, built)
    }

    /// Nullability flags of the first view's columns, in column order.
    fn nullability(views: &[TableInfo]) -> Vec<bool> {
        views[0].columns.iter().map(|c| c.is_nullable).collect()
    }

    /// Primary-key flags of the first view's columns, in column order.
    fn primary_keys(views: &[TableInfo]) -> Vec<bool> {
        views[0].columns.iter().map(|c| c.is_primary_key).collect()
    }

    // ---- resolve_view_nullability ----

    #[test]
    fn test_resolve_unique_not_null() {
        let tables = [make_table("users", vec![("id", false), ("name", false)])];
        let mut views = [make_view("my_view", vec!["id", "name"])];
        resolve_view_nullability(&mut views, &tables);
        assert_eq!(nullability(&views), [false, false]);
    }

    #[test]
    fn test_resolve_nullable_source() {
        let tables = [make_table("users", vec![("id", false), ("name", true)])];
        let mut views = [make_view("my_view", vec!["id", "name"])];
        resolve_view_nullability(&mut views, &tables);
        assert_eq!(nullability(&views), [false, true]);
    }

    #[test]
    fn test_resolve_ambiguous_stays_nullable() {
        // The same column name in two tables cannot be attributed to either one.
        let tables = [
            make_table("users", vec![("id", false)]),
            make_table("orders", vec![("id", false)]),
        ];
        let mut views = [make_view("my_view", vec!["id"])];
        resolve_view_nullability(&mut views, &tables);
        assert_eq!(nullability(&views), [true]);
    }

    #[test]
    fn test_resolve_no_match() {
        let tables = [make_table("users", vec![("id", false)])];
        let mut views = [make_view("my_view", vec!["computed"])];
        resolve_view_nullability(&mut views, &tables);
        assert_eq!(nullability(&views), [true]);
    }

    #[test]
    fn test_resolve_empty_tables() {
        let mut views = [make_view("my_view", vec!["id"])];
        resolve_view_nullability(&mut views, &[]);
        assert_eq!(nullability(&views), [true]);
    }

    // ---- resolve_view_primary_keys ----

    #[test]
    fn test_resolve_pk_unique_match() {
        let tables = [make_table_with_pk("users", vec![("id", true), ("name", false)])];
        let mut views = [make_view("my_view", vec!["id", "name"])];
        resolve_view_primary_keys(&mut views, &tables);
        assert_eq!(primary_keys(&views), [true, false]);
    }

    #[test]
    fn test_resolve_pk_ambiguous() {
        // Two candidate primary keys with the same name: leave the view column untouched.
        let tables = [
            make_table_with_pk("users", vec![("id", true)]),
            make_table_with_pk("orders", vec![("id", true)]),
        ];
        let mut views = [make_view("my_view", vec!["id"])];
        resolve_view_primary_keys(&mut views, &tables);
        assert_eq!(primary_keys(&views), [false]);
    }

    #[test]
    fn test_resolve_pk_no_match() {
        let tables = [make_table_with_pk("users", vec![("id", true)])];
        let mut views = [make_view("my_view", vec!["computed"])];
        resolve_view_primary_keys(&mut views, &tables);
        assert_eq!(primary_keys(&views), [false]);
    }

    #[test]
    fn test_resolve_pk_empty_tables() {
        let mut views = [make_view("my_view", vec!["id"])];
        resolve_view_primary_keys(&mut views, &[]);
        assert_eq!(primary_keys(&views), [false]);
    }
}