Skip to main content

sqlx_gen/introspect/
postgres.rs

1use std::collections::HashMap;
2
3use crate::error::Result;
4use sqlx::PgPool;
5
6use super::{ColumnInfo, CompositeTypeInfo, DomainInfo, EnumInfo, SchemaInfo, TableInfo};
7
8pub async fn introspect(
9    pool: &PgPool,
10    schemas: &[String],
11    include_views: bool,
12) -> Result<SchemaInfo> {
13    let tables = fetch_tables(pool, schemas).await?;
14    let mut views = if include_views {
15        fetch_views(pool, schemas).await?
16    } else {
17        Vec::new()
18    };
19
20    if !views.is_empty() {
21        let nullability_info = fetch_view_column_nullability(pool, schemas).await?;
22        resolve_view_nullability(&mut views, &nullability_info);
23
24        let pk_info = fetch_view_column_primary_keys(pool, schemas).await?;
25        resolve_view_primary_keys(&mut views, &pk_info);
26    }
27
28    let enums = fetch_enums(pool, schemas).await?;
29    let composite_types = fetch_composite_types(pool, schemas).await?;
30    let domains = fetch_domains(pool, schemas).await?;
31
32    Ok(SchemaInfo {
33        tables,
34        views,
35        enums,
36        composite_types,
37        domains,
38    })
39}
40
41async fn fetch_tables(pool: &PgPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
42    let rows = sqlx::query_as::<
43        _,
44        (
45            String,
46            String,
47            String,
48            String,
49            String,
50            String,
51            String,
52            i32,
53            bool,
54            Option<String>,
55        ),
56    >(
57        r#"
58        SELECT
59            c.table_schema,
60            c.table_name,
61            c.column_name,
62            c.data_type,
63            COALESCE(c.udt_name, c.data_type) as udt_name,
64            COALESCE(c.udt_schema, '') as udt_schema,
65            c.is_nullable,
66            c.ordinal_position,
67            CASE WHEN kcu.column_name IS NOT NULL THEN true ELSE false END AS is_primary_key,
68            c.column_default
69        FROM information_schema.columns c
70        JOIN information_schema.tables t
71            ON t.table_schema = c.table_schema
72            AND t.table_name = c.table_name
73            AND t.table_type = 'BASE TABLE'
74        LEFT JOIN information_schema.table_constraints tc
75            ON tc.table_schema = c.table_schema
76            AND tc.table_name = c.table_name
77            AND tc.constraint_type = 'PRIMARY KEY'
78        LEFT JOIN information_schema.key_column_usage kcu
79            ON kcu.constraint_name = tc.constraint_name
80            AND kcu.constraint_schema = tc.constraint_schema
81            AND kcu.column_name = c.column_name
82        WHERE c.table_schema = ANY($1)
83        ORDER BY c.table_schema, c.table_name, c.ordinal_position
84        "#,
85    )
86    .bind(schemas)
87    .fetch_all(pool)
88    .await?;
89
90    let mut tables: Vec<TableInfo> = Vec::new();
91    let mut current_key: Option<(String, String)> = None;
92
93    for (
94        schema,
95        table,
96        col_name,
97        data_type,
98        udt_name,
99        udt_schema,
100        nullable,
101        ordinal,
102        is_pk,
103        column_default,
104    ) in rows
105    {
106        let key = (schema.clone(), table.clone());
107        if current_key.as_ref() != Some(&key) {
108            current_key = Some(key);
109            tables.push(TableInfo {
110                schema_name: schema.clone(),
111                name: table.clone(),
112                columns: Vec::new(),
113            });
114        }
115        let last = tables.last_mut().ok_or_else(|| {
116            crate::error::Error::Config(
117                "Internal sqlx-gen bug: tables vector empty after push".to_string(),
118            )
119        })?;
120        last.columns.push(ColumnInfo {
121            name: col_name,
122            data_type,
123            udt_name,
124            udt_schema: if udt_schema.is_empty() {
125                None
126            } else {
127                Some(udt_schema)
128            },
129            is_nullable: nullable == "YES",
130            is_primary_key: is_pk,
131            ordinal_position: ordinal,
132            schema_name: schema,
133            column_default,
134        });
135    }
136
137    Ok(tables)
138}
139
140async fn fetch_views(pool: &PgPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
141    let rows = sqlx::query_as::<
142        _,
143        (
144            String,
145            String,
146            String,
147            String,
148            String,
149            String,
150            String,
151            i32,
152            Option<String>,
153        ),
154    >(
155        r#"
156        SELECT
157            c.table_schema,
158            c.table_name,
159            c.column_name,
160            c.data_type,
161            COALESCE(c.udt_name, c.data_type) as udt_name,
162            COALESCE(c.udt_schema, '') as udt_schema,
163            c.is_nullable,
164            c.ordinal_position,
165            c.column_default
166        FROM information_schema.columns c
167        JOIN information_schema.tables t
168            ON t.table_schema = c.table_schema
169            AND t.table_name = c.table_name
170            AND t.table_type = 'VIEW'
171        WHERE c.table_schema = ANY($1)
172        ORDER BY c.table_schema, c.table_name, c.ordinal_position
173        "#,
174    )
175    .bind(schemas)
176    .fetch_all(pool)
177    .await?;
178
179    let mut views: Vec<TableInfo> = Vec::new();
180    let mut current_key: Option<(String, String)> = None;
181
182    for (
183        schema,
184        table,
185        col_name,
186        data_type,
187        udt_name,
188        udt_schema,
189        nullable,
190        ordinal,
191        column_default,
192    ) in rows
193    {
194        let key = (schema.clone(), table.clone());
195        if current_key.as_ref() != Some(&key) {
196            current_key = Some(key);
197            views.push(TableInfo {
198                schema_name: schema.clone(),
199                name: table.clone(),
200                columns: Vec::new(),
201            });
202        }
203        let last = views.last_mut().ok_or_else(|| {
204            crate::error::Error::Config(
205                "Internal sqlx-gen bug: views vector empty after push".to_string(),
206            )
207        })?;
208        last.columns.push(ColumnInfo {
209            name: col_name,
210            data_type,
211            udt_name,
212            udt_schema: if udt_schema.is_empty() {
213                None
214            } else {
215                Some(udt_schema)
216            },
217            is_nullable: nullable == "YES",
218            is_primary_key: false,
219            ordinal_position: ordinal,
220            schema_name: schema,
221            column_default,
222        });
223    }
224
225    Ok(views)
226}
227
/// One source-table column referenced by a plain view, paired with that
/// column's `NOT NULL` flag, as produced by `fetch_view_column_nullability`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ViewColumnNullability {
    /// Schema (`pg_namespace.nspname`) of the view.
    view_schema: String,
    /// Name (`pg_class.relname`) of the view.
    view_name: String,
    /// Name of the referenced column in the *source* relation, which is not
    /// necessarily the view's own output column name.
    source_column_name: String,
    /// `pg_attribute.attnotnull` of the source column.
    source_not_null: bool,
}
234
235async fn fetch_view_column_nullability(
236    pool: &PgPool,
237    schemas: &[String],
238) -> Result<Vec<ViewColumnNullability>> {
239    let rows = sqlx::query_as::<_, (String, String, String, bool)>(
240        r#"
241        SELECT DISTINCT
242            v_ns.nspname AS view_schema,
243            v.relname AS view_name,
244            src_attr.attname AS source_column_name,
245            src_attr.attnotnull AS source_not_null
246        FROM pg_class v
247        JOIN pg_namespace v_ns ON v_ns.oid = v.relnamespace
248        JOIN pg_rewrite rw ON rw.ev_class = v.oid
249        JOIN pg_depend d ON d.objid = rw.oid
250            AND d.classid = 'pg_rewrite'::regclass
251            AND d.refobjsubid > 0
252            AND d.deptype = 'n'
253        JOIN pg_attribute src_attr ON src_attr.attrelid = d.refobjid
254            AND src_attr.attnum = d.refobjsubid
255            AND NOT src_attr.attisdropped
256        WHERE v_ns.nspname = ANY($1)
257          AND v.relkind = 'v'
258        "#,
259    )
260    .bind(schemas)
261    .fetch_all(pool)
262    .await?;
263
264    Ok(rows
265        .into_iter()
266        .map(
267            |(view_schema, view_name, source_column_name, source_not_null)| ViewColumnNullability {
268                view_schema,
269                view_name,
270                source_column_name,
271                source_not_null,
272            },
273        )
274        .collect())
275}
276
277fn resolve_view_nullability(views: &mut [TableInfo], nullability_info: &[ViewColumnNullability]) {
278    // Build lookup: (view_schema, view_name, column_name) -> Vec<is_not_null>
279    let mut lookup: HashMap<(&str, &str, &str), Vec<bool>> = HashMap::new();
280    for info in nullability_info {
281        lookup
282            .entry((&info.view_schema, &info.view_name, &info.source_column_name))
283            .or_default()
284            .push(info.source_not_null);
285    }
286
287    for view in views.iter_mut() {
288        for col in view.columns.iter_mut() {
289            if let Some(not_null_flags) = lookup.get(&(
290                view.schema_name.as_str(),
291                view.name.as_str(),
292                col.name.as_str(),
293            )) {
294                // Only mark as non-nullable if ALL source columns are NOT NULL
295                if !not_null_flags.is_empty() && not_null_flags.iter().all(|&nn| nn) {
296                    col.is_nullable = false;
297                }
298            }
299        }
300    }
301}
302
/// One source-table column referenced by a plain view, paired with whether that
/// column belongs to its relation's primary key, as produced by
/// `fetch_view_column_primary_keys`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ViewColumnPrimaryKey {
    /// Schema (`pg_namespace.nspname`) of the view.
    view_schema: String,
    /// Name (`pg_class.relname`) of the view.
    view_name: String,
    /// Name of the referenced column in the *source* relation, which is not
    /// necessarily the view's own output column name.
    source_column_name: String,
    /// Whether the source column appears in a `contype = 'p'` constraint.
    source_is_pk: bool,
}
309
310async fn fetch_view_column_primary_keys(
311    pool: &PgPool,
312    schemas: &[String],
313) -> Result<Vec<ViewColumnPrimaryKey>> {
314    let rows = sqlx::query_as::<_, (String, String, String, bool)>(
315        r#"
316        SELECT DISTINCT
317            v_ns.nspname AS view_schema,
318            v.relname AS view_name,
319            src_attr.attname AS source_column_name,
320            COALESCE(
321                EXISTS (
322                    SELECT 1
323                    FROM pg_constraint con
324                    WHERE con.conrelid = src_attr.attrelid
325                      AND con.contype = 'p'
326                      AND src_attr.attnum = ANY(con.conkey)
327                ),
328                false
329            ) AS source_is_pk
330        FROM pg_class v
331        JOIN pg_namespace v_ns ON v_ns.oid = v.relnamespace
332        JOIN pg_rewrite rw ON rw.ev_class = v.oid
333        JOIN pg_depend d ON d.objid = rw.oid
334            AND d.classid = 'pg_rewrite'::regclass
335            AND d.refobjsubid > 0
336            AND d.deptype = 'n'
337        JOIN pg_attribute src_attr ON src_attr.attrelid = d.refobjid
338            AND src_attr.attnum = d.refobjsubid
339            AND NOT src_attr.attisdropped
340        WHERE v_ns.nspname = ANY($1)
341          AND v.relkind = 'v'
342        "#,
343    )
344    .bind(schemas)
345    .fetch_all(pool)
346    .await?;
347
348    Ok(rows
349        .into_iter()
350        .map(
351            |(view_schema, view_name, source_column_name, source_is_pk)| ViewColumnPrimaryKey {
352                view_schema,
353                view_name,
354                source_column_name,
355                source_is_pk,
356            },
357        )
358        .collect())
359}
360
361fn resolve_view_primary_keys(views: &mut [TableInfo], pk_info: &[ViewColumnPrimaryKey]) {
362    // Build lookup: (view_schema, view_name, column_name) -> Vec<is_pk>
363    let mut lookup: HashMap<(&str, &str, &str), Vec<bool>> = HashMap::new();
364    for info in pk_info {
365        lookup
366            .entry((&info.view_schema, &info.view_name, &info.source_column_name))
367            .or_default()
368            .push(info.source_is_pk);
369    }
370
371    for view in views.iter_mut() {
372        for col in view.columns.iter_mut() {
373            if let Some(pk_flags) = lookup.get(&(
374                view.schema_name.as_str(),
375                view.name.as_str(),
376                col.name.as_str(),
377            )) {
378                // Only mark as PK if ALL source columns are PKs
379                if !pk_flags.is_empty() && pk_flags.iter().all(|&pk| pk) {
380                    col.is_primary_key = true;
381                }
382            }
383        }
384    }
385}
386
387async fn fetch_enums(pool: &PgPool, schemas: &[String]) -> Result<Vec<EnumInfo>> {
388    let rows = sqlx::query_as::<_, (String, String, String)>(
389        r#"
390        SELECT
391            n.nspname AS schema_name,
392            t.typname AS enum_name,
393            e.enumlabel AS variant
394        FROM pg_catalog.pg_type t
395        JOIN pg_catalog.pg_enum e ON e.enumtypid = t.oid
396        JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
397        WHERE n.nspname = ANY($1)
398        ORDER BY n.nspname, t.typname, e.enumsortorder
399        "#,
400    )
401    .bind(schemas)
402    .fetch_all(pool)
403    .await?;
404
405    let mut enums: Vec<EnumInfo> = Vec::new();
406    let mut current_key: Option<(String, String)> = None;
407
408    for (schema, name, variant) in rows {
409        let key = (schema.clone(), name.clone());
410        if current_key.as_ref() != Some(&key) {
411            current_key = Some(key);
412            enums.push(EnumInfo {
413                schema_name: schema,
414                name,
415                variants: Vec::new(),
416                default_variant: None,
417            });
418        }
419        let last = enums.last_mut().ok_or_else(|| {
420            crate::error::Error::Config(
421                "Internal sqlx-gen bug: enums vector empty after push".to_string(),
422            )
423        })?;
424        last.variants.push(variant);
425    }
426
427    Ok(enums)
428}
429
430async fn fetch_composite_types(
431    pool: &PgPool,
432    schemas: &[String],
433) -> Result<Vec<CompositeTypeInfo>> {
434    let rows = sqlx::query_as::<_, (String, String, String, String, String, i32)>(
435        r#"
436        SELECT
437            n.nspname AS schema_name,
438            t.typname AS type_name,
439            a.attname AS field_name,
440            COALESCE(ft.typname, '') AS field_type,
441            CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable,
442            a.attnum AS ordinal
443        FROM pg_catalog.pg_type t
444        JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
445        JOIN pg_catalog.pg_class c ON c.oid = t.typrelid
446        JOIN pg_catalog.pg_attribute a ON a.attrelid = c.oid AND a.attnum > 0 AND NOT a.attisdropped
447        JOIN pg_catalog.pg_type ft ON ft.oid = a.atttypid
448        WHERE t.typtype = 'c'
449            AND n.nspname = ANY($1)
450            AND NOT EXISTS (
451                SELECT 1 FROM information_schema.tables it
452                WHERE it.table_schema = n.nspname AND it.table_name = t.typname
453            )
454        ORDER BY n.nspname, t.typname, a.attnum
455        "#,
456    )
457    .bind(schemas)
458    .fetch_all(pool)
459    .await?;
460
461    let mut composites: Vec<CompositeTypeInfo> = Vec::new();
462    let mut current_key: Option<(String, String)> = None;
463
464    for (schema, type_name, field_name, field_type, nullable, ordinal) in rows {
465        let key = (schema.clone(), type_name.clone());
466        if current_key.as_ref() != Some(&key) {
467            current_key = Some(key);
468            composites.push(CompositeTypeInfo {
469                schema_name: schema.clone(),
470                name: type_name,
471                fields: Vec::new(),
472            });
473        }
474        let last = composites.last_mut().ok_or_else(|| {
475            crate::error::Error::Config(
476                "Internal sqlx-gen bug: composites vector empty after push".to_string(),
477            )
478        })?;
479        last.fields.push(ColumnInfo {
480            name: field_name,
481            data_type: field_type.clone(),
482            udt_name: field_type,
483            udt_schema: None,
484            is_nullable: nullable == "YES",
485            is_primary_key: false,
486            ordinal_position: ordinal,
487            schema_name: schema,
488            column_default: None,
489        });
490    }
491
492    Ok(composites)
493}
494
495async fn fetch_domains(pool: &PgPool, schemas: &[String]) -> Result<Vec<DomainInfo>> {
496    let rows = sqlx::query_as::<_, (String, String, String)>(
497        r#"
498        SELECT
499            n.nspname AS schema_name,
500            t.typname AS domain_name,
501            bt.typname AS base_type
502        FROM pg_catalog.pg_type t
503        JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
504        JOIN pg_catalog.pg_type bt ON bt.oid = t.typbasetype
505        WHERE t.typtype = 'd'
506            AND n.nspname = ANY($1)
507        ORDER BY n.nspname, t.typname
508        "#,
509    )
510    .bind(schemas)
511    .fetch_all(pool)
512    .await?;
513
514    Ok(rows
515        .into_iter()
516        .map(|(schema, name, base_type)| DomainInfo {
517            schema_name: schema,
518            name,
519            base_type,
520        })
521        .collect())
522}
523
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a view whose columns are nullable, non-key `text` columns numbered
    /// from 0 in the order given.
    fn make_view(schema: &str, name: &str, columns: &[&str]) -> TableInfo {
        TableInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            columns: columns
                .iter()
                .zip(0..)
                .map(|(&column, position)| ColumnInfo {
                    name: column.to_string(),
                    data_type: "text".to_string(),
                    udt_name: "text".to_string(),
                    udt_schema: None,
                    is_nullable: true,
                    is_primary_key: false,
                    ordinal_position: position,
                    schema_name: schema.to_string(),
                    column_default: None,
                })
                .collect(),
        }
    }

    // --- resolve_view_nullability tests ---

    fn make_nullability(
        view_schema: &str,
        view_name: &str,
        source_column: &str,
        not_null: bool,
    ) -> ViewColumnNullability {
        ViewColumnNullability {
            view_schema: view_schema.to_string(),
            view_name: view_name.to_string(),
            source_column_name: source_column.to_string(),
            source_not_null: not_null,
        }
    }

    #[test]
    fn test_resolve_not_null_column() {
        let mut views = vec![make_view("public", "my_view", &["id", "name"])];
        let info = [
            make_nullability("public", "my_view", "id", true),
            make_nullability("public", "my_view", "name", true),
        ];
        resolve_view_nullability(&mut views, &info);
        assert!(views[0].columns.iter().all(|column| !column.is_nullable));
    }

    #[test]
    fn test_resolve_mixed_sources() {
        // One nullable source is enough to keep the view column nullable.
        let mut views = vec![make_view("public", "my_view", &["id"])];
        let info = [
            make_nullability("public", "my_view", "id", true),
            make_nullability("public", "my_view", "id", false),
        ];
        resolve_view_nullability(&mut views, &info);
        assert!(views[0].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_no_match_stays_nullable() {
        let mut views = vec![make_view("public", "my_view", &["computed_col"])];
        let info = [make_nullability("public", "my_view", "id", true)];
        resolve_view_nullability(&mut views, &info);
        assert!(views[0].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_empty_info() {
        let mut views = vec![make_view("public", "my_view", &["id"])];
        resolve_view_nullability(&mut views, &[]);
        assert!(views[0].columns[0].is_nullable);
    }

    #[test]
    fn test_resolve_cross_schema() {
        let mut views = vec![
            make_view("public", "v1", &["id"]),
            make_view("auth", "v2", &["id"]),
        ];
        let info = [
            make_nullability("public", "v1", "id", true),
            make_nullability("auth", "v2", "id", false),
        ];
        resolve_view_nullability(&mut views, &info);
        let (public_view, auth_view) = (&views[0], &views[1]);
        assert!(!public_view.columns[0].is_nullable);
        assert!(auth_view.columns[0].is_nullable);
    }

    // --- resolve_view_primary_keys tests ---

    fn make_pk_info(
        view_schema: &str,
        view_name: &str,
        source_column: &str,
        is_pk: bool,
    ) -> ViewColumnPrimaryKey {
        ViewColumnPrimaryKey {
            view_schema: view_schema.to_string(),
            view_name: view_name.to_string(),
            source_column_name: source_column.to_string(),
            source_is_pk: is_pk,
        }
    }

    #[test]
    fn test_resolve_pk_column() {
        let mut views = vec![make_view("public", "my_view", &["id", "name"])];
        let info = [
            make_pk_info("public", "my_view", "id", true),
            make_pk_info("public", "my_view", "name", false),
        ];
        resolve_view_primary_keys(&mut views, &info);
        let flags: Vec<bool> = views[0].columns.iter().map(|c| c.is_primary_key).collect();
        assert_eq!(flags, vec![true, false]);
    }

    #[test]
    fn test_resolve_pk_mixed_sources() {
        // A single non-key source is enough to withhold the PK flag.
        let mut views = vec![make_view("public", "my_view", &["id"])];
        let info = [
            make_pk_info("public", "my_view", "id", true),
            make_pk_info("public", "my_view", "id", false),
        ];
        resolve_view_primary_keys(&mut views, &info);
        assert!(!views[0].columns[0].is_primary_key);
    }

    #[test]
    fn test_resolve_pk_no_match() {
        let mut views = vec![make_view("public", "my_view", &["computed_col"])];
        let info = [make_pk_info("public", "my_view", "id", true)];
        resolve_view_primary_keys(&mut views, &info);
        assert!(!views[0].columns[0].is_primary_key);
    }

    #[test]
    fn test_resolve_pk_empty_info() {
        let mut views = vec![make_view("public", "my_view", &["id"])];
        resolve_view_primary_keys(&mut views, &[]);
        assert!(!views[0].columns[0].is_primary_key);
    }

    #[test]
    fn test_resolve_pk_cross_schema() {
        let mut views = vec![
            make_view("public", "v1", &["id"]),
            make_view("auth", "v2", &["id"]),
        ];
        let info = [
            make_pk_info("public", "v1", "id", true),
            make_pk_info("auth", "v2", "id", false),
        ];
        resolve_view_primary_keys(&mut views, &info);
        let (public_view, auth_view) = (&views[0], &views[1]);
        assert!(public_view.columns[0].is_primary_key);
        assert!(!auth_view.columns[0].is_primary_key);
    }
}