Skip to main content

sqlx_gen/introspect/
postgres.rs

1use anyhow::Result;
2use sqlx::PgPool;
3
4use super::{ColumnInfo, CompositeTypeInfo, DomainInfo, EnumInfo, SchemaInfo, TableInfo};
5
6pub async fn introspect(
7    pool: &PgPool,
8    schemas: &[String],
9    include_views: bool,
10) -> Result<SchemaInfo> {
11    let tables = fetch_tables(pool, schemas).await?;
12    let views = if include_views {
13        fetch_views(pool, schemas).await?
14    } else {
15        Vec::new()
16    };
17    let enums = fetch_enums(pool, schemas).await?;
18    let composite_types = fetch_composite_types(pool, schemas).await?;
19    let domains = fetch_domains(pool, schemas).await?;
20
21    Ok(SchemaInfo {
22        tables,
23        views,
24        enums,
25        composite_types,
26        domains,
27    })
28}
29
30async fn fetch_tables(pool: &PgPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
31    let rows = sqlx::query_as::<_, (String, String, String, String, String, String, i32)>(
32        r#"
33        SELECT
34            c.table_schema,
35            c.table_name,
36            c.column_name,
37            c.data_type,
38            COALESCE(c.udt_name, c.data_type) as udt_name,
39            c.is_nullable,
40            c.ordinal_position
41        FROM information_schema.columns c
42        JOIN information_schema.tables t
43            ON t.table_schema = c.table_schema
44            AND t.table_name = c.table_name
45            AND t.table_type = 'BASE TABLE'
46        WHERE c.table_schema = ANY($1)
47        ORDER BY c.table_schema, c.table_name, c.ordinal_position
48        "#,
49    )
50    .bind(schemas)
51    .fetch_all(pool)
52    .await?;
53
54    let mut tables: Vec<TableInfo> = Vec::new();
55    let mut current_key: Option<(String, String)> = None;
56
57    for (schema, table, col_name, data_type, udt_name, nullable, ordinal) in rows {
58        let key = (schema.clone(), table.clone());
59        if current_key.as_ref() != Some(&key) {
60            current_key = Some(key);
61            tables.push(TableInfo {
62                schema_name: schema.clone(),
63                name: table.clone(),
64                columns: Vec::new(),
65            });
66        }
67        tables.last_mut().unwrap().columns.push(ColumnInfo {
68            name: col_name,
69            data_type,
70            udt_name,
71            is_nullable: nullable == "YES",
72            ordinal_position: ordinal,
73            schema_name: schema,
74        });
75    }
76
77    Ok(tables)
78}
79
80async fn fetch_views(pool: &PgPool, schemas: &[String]) -> Result<Vec<TableInfo>> {
81    let rows = sqlx::query_as::<_, (String, String, String, String, String, String, i32)>(
82        r#"
83        SELECT
84            c.table_schema,
85            c.table_name,
86            c.column_name,
87            c.data_type,
88            COALESCE(c.udt_name, c.data_type) as udt_name,
89            c.is_nullable,
90            c.ordinal_position
91        FROM information_schema.columns c
92        JOIN information_schema.tables t
93            ON t.table_schema = c.table_schema
94            AND t.table_name = c.table_name
95            AND t.table_type = 'VIEW'
96        WHERE c.table_schema = ANY($1)
97        ORDER BY c.table_schema, c.table_name, c.ordinal_position
98        "#,
99    )
100    .bind(schemas)
101    .fetch_all(pool)
102    .await?;
103
104    let mut views: Vec<TableInfo> = Vec::new();
105    let mut current_key: Option<(String, String)> = None;
106
107    for (schema, table, col_name, data_type, udt_name, nullable, ordinal) in rows {
108        let key = (schema.clone(), table.clone());
109        if current_key.as_ref() != Some(&key) {
110            current_key = Some(key);
111            views.push(TableInfo {
112                schema_name: schema.clone(),
113                name: table.clone(),
114                columns: Vec::new(),
115            });
116        }
117        views.last_mut().unwrap().columns.push(ColumnInfo {
118            name: col_name,
119            data_type,
120            udt_name,
121            is_nullable: nullable == "YES",
122            ordinal_position: ordinal,
123            schema_name: schema,
124        });
125    }
126
127    Ok(views)
128}
129
130async fn fetch_enums(pool: &PgPool, schemas: &[String]) -> Result<Vec<EnumInfo>> {
131    let rows = sqlx::query_as::<_, (String, String, String)>(
132        r#"
133        SELECT
134            n.nspname AS schema_name,
135            t.typname AS enum_name,
136            e.enumlabel AS variant
137        FROM pg_catalog.pg_type t
138        JOIN pg_catalog.pg_enum e ON e.enumtypid = t.oid
139        JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
140        WHERE n.nspname = ANY($1)
141        ORDER BY n.nspname, t.typname, e.enumsortorder
142        "#,
143    )
144    .bind(schemas)
145    .fetch_all(pool)
146    .await?;
147
148    let mut enums: Vec<EnumInfo> = Vec::new();
149    let mut current_key: Option<(String, String)> = None;
150
151    for (schema, name, variant) in rows {
152        let key = (schema.clone(), name.clone());
153        if current_key.as_ref() != Some(&key) {
154            current_key = Some(key);
155            enums.push(EnumInfo {
156                schema_name: schema,
157                name,
158                variants: Vec::new(),
159            });
160        }
161        enums.last_mut().unwrap().variants.push(variant);
162    }
163
164    Ok(enums)
165}
166
167async fn fetch_composite_types(
168    pool: &PgPool,
169    schemas: &[String],
170) -> Result<Vec<CompositeTypeInfo>> {
171    let rows = sqlx::query_as::<_, (String, String, String, String, String, i32)>(
172        r#"
173        SELECT
174            n.nspname AS schema_name,
175            t.typname AS type_name,
176            a.attname AS field_name,
177            COALESCE(ft.typname, '') AS field_type,
178            CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable,
179            a.attnum AS ordinal
180        FROM pg_catalog.pg_type t
181        JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
182        JOIN pg_catalog.pg_class c ON c.oid = t.typrelid
183        JOIN pg_catalog.pg_attribute a ON a.attrelid = c.oid AND a.attnum > 0 AND NOT a.attisdropped
184        JOIN pg_catalog.pg_type ft ON ft.oid = a.atttypid
185        WHERE t.typtype = 'c'
186            AND n.nspname = ANY($1)
187            AND NOT EXISTS (
188                SELECT 1 FROM information_schema.tables it
189                WHERE it.table_schema = n.nspname AND it.table_name = t.typname
190            )
191        ORDER BY n.nspname, t.typname, a.attnum
192        "#,
193    )
194    .bind(schemas)
195    .fetch_all(pool)
196    .await?;
197
198    let mut composites: Vec<CompositeTypeInfo> = Vec::new();
199    let mut current_key: Option<(String, String)> = None;
200
201    for (schema, type_name, field_name, field_type, nullable, ordinal) in rows {
202        let key = (schema.clone(), type_name.clone());
203        if current_key.as_ref() != Some(&key) {
204            current_key = Some(key);
205            composites.push(CompositeTypeInfo {
206                schema_name: schema.clone(),
207                name: type_name,
208                fields: Vec::new(),
209            });
210        }
211        composites.last_mut().unwrap().fields.push(ColumnInfo {
212            name: field_name,
213            data_type: field_type.clone(),
214            udt_name: field_type,
215            is_nullable: nullable == "YES",
216            ordinal_position: ordinal,
217            schema_name: schema,
218        });
219    }
220
221    Ok(composites)
222}
223
224async fn fetch_domains(pool: &PgPool, schemas: &[String]) -> Result<Vec<DomainInfo>> {
225    let rows = sqlx::query_as::<_, (String, String, String)>(
226        r#"
227        SELECT
228            n.nspname AS schema_name,
229            t.typname AS domain_name,
230            bt.typname AS base_type
231        FROM pg_catalog.pg_type t
232        JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
233        JOIN pg_catalog.pg_type bt ON bt.oid = t.typbasetype
234        WHERE t.typtype = 'd'
235            AND n.nspname = ANY($1)
236        ORDER BY n.nspname, t.typname
237        "#,
238    )
239    .bind(schemas)
240    .fetch_all(pool)
241    .await?;
242
243    Ok(rows
244        .into_iter()
245        .map(|(schema, name, base_type)| DomainInfo {
246            schema_name: schema,
247            name,
248            base_type,
249        })
250        .collect())
251}