1use std::collections::HashMap;
2
3use crate::error::Result;
4use sqlx::PgPool;
5
6use super::{ColumnInfo, CompositeTypeInfo, DomainInfo, EnumInfo, SchemaInfo, TableInfo};
7
8pub async fn introspect(
9 pool: &PgPool,
10 schemas: &[String],
11 include_views: bool,
12) -> Result<SchemaInfo> {
13 let tables = fetch_tables(pool, schemas).await?;
14 let mut views = if include_views {
15 fetch_views(pool, schemas).await?
16 } else {
17 Vec::new()
18 };
19
20 if !views.is_empty() {
21 let nullability_info = fetch_view_column_nullability(pool, schemas).await?;
22 resolve_view_nullability(&mut views, &nullability_info);
23 }
24
25 let enums = fetch_enums(pool, schemas).await?;
26 let composite_types = fetch_composite_types(pool, schemas).await?;
27 let domains = fetch_domains(pool, schemas).await?;
28
29 Ok(SchemaInfo {
30 tables,
31 views,
32 enums,
33 composite_types,
34 domains,
35 })
36}
37
38async fn fetch_tables(pool: &PgPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
39 let rows = sqlx::query_as::<_, (String, String, String, String, String, String, i32, bool, Option<String>)>(
40 r#"
41 SELECT
42 c.table_schema,
43 c.table_name,
44 c.column_name,
45 c.data_type,
46 COALESCE(c.udt_name, c.data_type) as udt_name,
47 c.is_nullable,
48 c.ordinal_position,
49 CASE WHEN kcu.column_name IS NOT NULL THEN true ELSE false END AS is_primary_key,
50 c.column_default
51 FROM information_schema.columns c
52 JOIN information_schema.tables t
53 ON t.table_schema = c.table_schema
54 AND t.table_name = c.table_name
55 AND t.table_type = 'BASE TABLE'
56 LEFT JOIN information_schema.table_constraints tc
57 ON tc.table_schema = c.table_schema
58 AND tc.table_name = c.table_name
59 AND tc.constraint_type = 'PRIMARY KEY'
60 LEFT JOIN information_schema.key_column_usage kcu
61 ON kcu.constraint_name = tc.constraint_name
62 AND kcu.constraint_schema = tc.constraint_schema
63 AND kcu.column_name = c.column_name
64 WHERE c.table_schema = ANY($1)
65 ORDER BY c.table_schema, c.table_name, c.ordinal_position
66 "#,
67 )
68 .bind(schemas)
69 .fetch_all(pool)
70 .await?;
71
72 let mut tables: Vec<TableInfo> = Vec::new();
73 let mut current_key: Option<(String, String)> = None;
74
75 for (schema, table, col_name, data_type, udt_name, nullable, ordinal, is_pk, column_default) in rows {
76 let key = (schema.clone(), table.clone());
77 if current_key.as_ref() != Some(&key) {
78 current_key = Some(key);
79 tables.push(TableInfo {
80 schema_name: schema.clone(),
81 name: table.clone(),
82 columns: Vec::new(),
83 });
84 }
85 tables.last_mut().unwrap().columns.push(ColumnInfo {
86 name: col_name,
87 data_type,
88 udt_name,
89 is_nullable: nullable == "YES",
90 is_primary_key: is_pk,
91 ordinal_position: ordinal,
92 schema_name: schema,
93 column_default,
94 });
95 }
96
97 Ok(tables)
98}
99
100async fn fetch_views(pool: &PgPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
101 let rows = sqlx::query_as::<_, (String, String, String, String, String, String, i32, Option<String>)>(
102 r#"
103 SELECT
104 c.table_schema,
105 c.table_name,
106 c.column_name,
107 c.data_type,
108 COALESCE(c.udt_name, c.data_type) as udt_name,
109 c.is_nullable,
110 c.ordinal_position,
111 c.column_default
112 FROM information_schema.columns c
113 JOIN information_schema.tables t
114 ON t.table_schema = c.table_schema
115 AND t.table_name = c.table_name
116 AND t.table_type = 'VIEW'
117 WHERE c.table_schema = ANY($1)
118 ORDER BY c.table_schema, c.table_name, c.ordinal_position
119 "#,
120 )
121 .bind(schemas)
122 .fetch_all(pool)
123 .await?;
124
125 let mut views: Vec<TableInfo> = Vec::new();
126 let mut current_key: Option<(String, String)> = None;
127
128 for (schema, table, col_name, data_type, udt_name, nullable, ordinal, column_default) in rows {
129 let key = (schema.clone(), table.clone());
130 if current_key.as_ref() != Some(&key) {
131 current_key = Some(key);
132 views.push(TableInfo {
133 schema_name: schema.clone(),
134 name: table.clone(),
135 columns: Vec::new(),
136 });
137 }
138 views.last_mut().unwrap().columns.push(ColumnInfo {
139 name: col_name,
140 data_type,
141 udt_name,
142 is_nullable: nullable == "YES",
143 is_primary_key: false,
144 ordinal_position: ordinal,
145 schema_name: schema,
146 column_default,
147 });
148 }
149
150 Ok(views)
151}
152
/// `NOT NULL` flag of one source column that a view depends on.
///
/// A view may have several entries with the same `source_column_name` when it
/// depends on same-named columns of different relations.
struct ViewColumnNullability {
    /// Schema the view lives in.
    view_schema: String,
    /// Name of the view.
    view_name: String,
    /// Name of the referenced column in the source relation (not necessarily
    /// the name the view exposes it under).
    source_column_name: String,
    /// Whether the source column is declared `NOT NULL`.
    source_not_null: bool,
}
159
160async fn fetch_view_column_nullability(
161 pool: &PgPool,
162 schemas: &[String],
163) -> Result<Vec<ViewColumnNullability>> {
164 let rows = sqlx::query_as::<_, (String, String, String, bool)>(
165 r#"
166 SELECT DISTINCT
167 v_ns.nspname AS view_schema,
168 v.relname AS view_name,
169 src_attr.attname AS source_column_name,
170 src_attr.attnotnull AS source_not_null
171 FROM pg_class v
172 JOIN pg_namespace v_ns ON v_ns.oid = v.relnamespace
173 JOIN pg_rewrite rw ON rw.ev_class = v.oid
174 JOIN pg_depend d ON d.objid = rw.oid
175 AND d.classid = 'pg_rewrite'::regclass
176 AND d.refobjsubid > 0
177 AND d.deptype = 'n'
178 JOIN pg_attribute src_attr ON src_attr.attrelid = d.refobjid
179 AND src_attr.attnum = d.refobjsubid
180 AND NOT src_attr.attisdropped
181 WHERE v_ns.nspname = ANY($1)
182 AND v.relkind = 'v'
183 "#,
184 )
185 .bind(schemas)
186 .fetch_all(pool)
187 .await?;
188
189 Ok(rows
190 .into_iter()
191 .map(
192 |(view_schema, view_name, source_column_name, source_not_null)| {
193 ViewColumnNullability {
194 view_schema,
195 view_name,
196 source_column_name,
197 source_not_null,
198 }
199 },
200 )
201 .collect())
202}
203
204fn resolve_view_nullability(
205 views: &mut [TableInfo],
206 nullability_info: &[ViewColumnNullability],
207) {
208 let mut lookup: HashMap<(&str, &str, &str), Vec<bool>> = HashMap::new();
210 for info in nullability_info {
211 lookup
212 .entry((&info.view_schema, &info.view_name, &info.source_column_name))
213 .or_default()
214 .push(info.source_not_null);
215 }
216
217 for view in views.iter_mut() {
218 for col in view.columns.iter_mut() {
219 if let Some(not_null_flags) = lookup.get(&(
220 view.schema_name.as_str(),
221 view.name.as_str(),
222 col.name.as_str(),
223 )) {
224 if !not_null_flags.is_empty() && not_null_flags.iter().all(|&nn| nn) {
226 col.is_nullable = false;
227 }
228 }
229 }
230 }
231}
232
233async fn fetch_enums(pool: &PgPool, schemas: &[String]) -> Result<Vec<EnumInfo>> {
234 let rows = sqlx::query_as::<_, (String, String, String)>(
235 r#"
236 SELECT
237 n.nspname AS schema_name,
238 t.typname AS enum_name,
239 e.enumlabel AS variant
240 FROM pg_catalog.pg_type t
241 JOIN pg_catalog.pg_enum e ON e.enumtypid = t.oid
242 JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
243 WHERE n.nspname = ANY($1)
244 ORDER BY n.nspname, t.typname, e.enumsortorder
245 "#,
246 )
247 .bind(schemas)
248 .fetch_all(pool)
249 .await?;
250
251 let mut enums: Vec<EnumInfo> = Vec::new();
252 let mut current_key: Option<(String, String)> = None;
253
254 for (schema, name, variant) in rows {
255 let key = (schema.clone(), name.clone());
256 if current_key.as_ref() != Some(&key) {
257 current_key = Some(key);
258 enums.push(EnumInfo {
259 schema_name: schema,
260 name,
261 variants: Vec::new(),
262 default_variant: None,
263 });
264 }
265 enums.last_mut().unwrap().variants.push(variant);
266 }
267
268 Ok(enums)
269}
270
271async fn fetch_composite_types(
272 pool: &PgPool,
273 schemas: &[String],
274) -> Result<Vec<CompositeTypeInfo>> {
275 let rows = sqlx::query_as::<_, (String, String, String, String, String, i32)>(
276 r#"
277 SELECT
278 n.nspname AS schema_name,
279 t.typname AS type_name,
280 a.attname AS field_name,
281 COALESCE(ft.typname, '') AS field_type,
282 CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable,
283 a.attnum AS ordinal
284 FROM pg_catalog.pg_type t
285 JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
286 JOIN pg_catalog.pg_class c ON c.oid = t.typrelid
287 JOIN pg_catalog.pg_attribute a ON a.attrelid = c.oid AND a.attnum > 0 AND NOT a.attisdropped
288 JOIN pg_catalog.pg_type ft ON ft.oid = a.atttypid
289 WHERE t.typtype = 'c'
290 AND n.nspname = ANY($1)
291 AND NOT EXISTS (
292 SELECT 1 FROM information_schema.tables it
293 WHERE it.table_schema = n.nspname AND it.table_name = t.typname
294 )
295 ORDER BY n.nspname, t.typname, a.attnum
296 "#,
297 )
298 .bind(schemas)
299 .fetch_all(pool)
300 .await?;
301
302 let mut composites: Vec<CompositeTypeInfo> = Vec::new();
303 let mut current_key: Option<(String, String)> = None;
304
305 for (schema, type_name, field_name, field_type, nullable, ordinal) in rows {
306 let key = (schema.clone(), type_name.clone());
307 if current_key.as_ref() != Some(&key) {
308 current_key = Some(key);
309 composites.push(CompositeTypeInfo {
310 schema_name: schema.clone(),
311 name: type_name,
312 fields: Vec::new(),
313 });
314 }
315 composites.last_mut().unwrap().fields.push(ColumnInfo {
316 name: field_name,
317 data_type: field_type.clone(),
318 udt_name: field_type,
319 is_nullable: nullable == "YES",
320 is_primary_key: false,
321 ordinal_position: ordinal,
322 schema_name: schema,
323 column_default: None,
324 });
325 }
326
327 Ok(composites)
328}
329
330async fn fetch_domains(pool: &PgPool, schemas: &[String]) -> Result<Vec<DomainInfo>> {
331 let rows = sqlx::query_as::<_, (String, String, String)>(
332 r#"
333 SELECT
334 n.nspname AS schema_name,
335 t.typname AS domain_name,
336 bt.typname AS base_type
337 FROM pg_catalog.pg_type t
338 JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
339 JOIN pg_catalog.pg_type bt ON bt.oid = t.typbasetype
340 WHERE t.typtype = 'd'
341 AND n.nspname = ANY($1)
342 ORDER BY n.nspname, t.typname
343 "#,
344 )
345 .bind(schemas)
346 .fetch_all(pool)
347 .await?;
348
349 Ok(rows
350 .into_iter()
351 .map(|(schema, name, base_type)| DomainInfo {
352 schema_name: schema,
353 name,
354 base_type,
355 })
356 .collect())
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a view whose columns are all nullable `text`, numbered from zero.
    fn make_view(schema: &str, name: &str, columns: Vec<&str>) -> TableInfo {
        let columns = columns
            .into_iter()
            .zip(0i32..)
            .map(|(column, position)| ColumnInfo {
                name: String::from(column),
                data_type: String::from("text"),
                udt_name: String::from("text"),
                is_nullable: true,
                is_primary_key: false,
                ordinal_position: position,
                schema_name: String::from(schema),
                column_default: None,
            })
            .collect();

        TableInfo {
            schema_name: String::from(schema),
            name: String::from(name),
            columns,
        }
    }

    /// Builds one nullability hint for `source_column` of the given view.
    fn make_nullability(
        view_schema: &str,
        view_name: &str,
        source_column: &str,
        not_null: bool,
    ) -> ViewColumnNullability {
        ViewColumnNullability {
            view_schema: String::from(view_schema),
            view_name: String::from(view_name),
            source_column_name: String::from(source_column),
            source_not_null: not_null,
        }
    }

    /// Returns the `is_nullable` flag of every column of `view`, in column order.
    fn nullable_flags(view: &TableInfo) -> Vec<bool> {
        view.columns.iter().map(|column| column.is_nullable).collect()
    }

    #[test]
    fn test_resolve_not_null_column() {
        let mut views = vec![make_view("public", "my_view", vec!["id", "name"])];
        let hints = [
            make_nullability("public", "my_view", "id", true),
            make_nullability("public", "my_view", "name", true),
        ];

        resolve_view_nullability(&mut views, &hints);

        assert_eq!(nullable_flags(&views[0]), [false, false]);
    }

    #[test]
    fn test_resolve_mixed_sources() {
        let mut views = vec![make_view("public", "my_view", vec!["id"])];
        let hints = [
            make_nullability("public", "my_view", "id", true),
            make_nullability("public", "my_view", "id", false),
        ];

        resolve_view_nullability(&mut views, &hints);

        // One nullable source is enough to keep the view column nullable.
        assert_eq!(nullable_flags(&views[0]), [true]);
    }

    #[test]
    fn test_resolve_no_match_stays_nullable() {
        let mut views = vec![make_view("public", "my_view", vec!["computed_col"])];
        let hints = [make_nullability("public", "my_view", "id", true)];

        resolve_view_nullability(&mut views, &hints);

        assert_eq!(nullable_flags(&views[0]), [true]);
    }

    #[test]
    fn test_resolve_empty_info() {
        let mut views = vec![make_view("public", "my_view", vec!["id"])];

        resolve_view_nullability(&mut views, &[]);

        assert_eq!(nullable_flags(&views[0]), [true]);
    }

    #[test]
    fn test_resolve_cross_schema() {
        let mut views = vec![
            make_view("public", "v1", vec!["id"]),
            make_view("auth", "v2", vec!["id"]),
        ];
        let hints = [
            make_nullability("public", "v1", "id", true),
            make_nullability("auth", "v2", "id", false),
        ];

        resolve_view_nullability(&mut views, &hints);

        assert_eq!(nullable_flags(&views[0]), [false]);
        assert_eq!(nullable_flags(&views[1]), [true]);
    }
}