Skip to main content

sqlx_gen/codegen/
composite_gen.rs

1use std::collections::{BTreeSet, HashMap};
2
3use heck::ToSnakeCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::{DatabaseKind, TimeCrate};
8use crate::codegen::naming::singularize;
9use crate::codegen::{imports_for_derives, is_rust_keyword, rust_type_name_for};
10use crate::introspect::{CompositeTypeInfo, SchemaInfo};
11use crate::typemap;
12
13pub fn generate_composite(
14    composite: &CompositeTypeInfo,
15    db_kind: DatabaseKind,
16    schema_info: &SchemaInfo,
17    extra_derives: &[String],
18    type_overrides: &HashMap<String, String>,
19    time_crate: TimeCrate,
20) -> (TokenStream, BTreeSet<String>) {
21    let mut imports = BTreeSet::new();
22    for imp in imports_for_derives(extra_derives) {
23        imports.insert(imp);
24    }
25    let rust_name = rust_type_name_for(schema_info, &composite.schema_name, &composite.name);
26    let struct_name = format_ident!("{}", rust_name);
27    let search_path_doc = if db_kind == DatabaseKind::Postgres
28        && !crate::codegen::is_default_schema(&composite.schema_name)
29    {
30        Some(format!(
31            "Lives in PostgreSQL schema `{schema}`. The sqlx connection must \
32             include `{schema}` in its search_path so PG resolves the \
33             unqualified `type_name = \"{name}\"` to this composite.",
34            schema = composite.schema_name,
35            name = composite.name,
36        ))
37    } else {
38        None
39    };
40
41    let doc = format!(
42        "Composite type: {}.{}",
43        composite.schema_name, composite.name
44    );
45
46    imports.insert("use serde::{Serialize, Deserialize};".to_string());
47    imports.insert("use sqlx_gen::SqlxGen;".to_string());
48    let mut derive_tokens = vec![
49        quote! { Debug },
50        quote! { Clone },
51        quote! { PartialEq },
52        quote! { Eq },
53        quote! { Serialize },
54        quote! { Deserialize },
55        quote! { sqlx::Type },
56        quote! { SqlxGen },
57    ];
58    for d in extra_derives {
59        let ident = format_ident!("{}", d);
60        derive_tokens.push(quote! { #ident });
61    }
62
63    // Always unqualified — sqlx 0.8's PgTypeInfo::with_name does not accept "schema.type"
64    // and emitting it triggers runtime decode errors. Non-public schemas require the
65    // connection's `search_path` to include the schema.
66    let pg_name = &composite.name;
67    let type_attr = quote! { #[sqlx(type_name = #pg_name)] };
68
69    let fields: Vec<TokenStream> = composite
70        .fields
71        .iter()
72        .map(|col| {
73            let rust_type =
74                typemap::map_column(col, db_kind, schema_info, type_overrides, time_crate);
75            if let Some(imp) = &rust_type.needs_import {
76                imports.insert(imp.clone());
77            }
78
79            let field_name_snake = col.name.to_snake_case();
80            let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
81                let prefix = singularize(&composite.name).to_snake_case();
82                let prefixed = format!("{}_{}", prefix, field_name_snake);
83                (prefixed, true)
84            } else {
85                let changed = field_name_snake != col.name;
86                (field_name_snake, changed)
87            };
88
89            let field_ident = format_ident!("{}", effective_name);
90            let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
91                let fallback = format_ident!("String");
92                quote! { #fallback }
93            });
94
95            let rename = if needs_rename {
96                let original = &col.name;
97                quote! { #[sqlx(rename = #original)] }
98            } else {
99                quote! {}
100            };
101
102            quote! {
103                #rename
104                pub #field_ident: #type_tokens,
105            }
106        })
107        .collect();
108
109    // `#[derive(sqlx::Type)]` with `#[sqlx(type_name = "x")]` auto-generates
110    // `impl PgHasArrayType` returning `_x`. Emitting a second impl triggers
111    // E0119 in the user's crate.
112    let _ = db_kind;
113
114    let search_path_doc_tokens = match &search_path_doc {
115        Some(m) => quote! { #[doc = #m] },
116        None => quote! {},
117    };
118    let tokens = quote! {
119        #[doc = #doc]
120        #search_path_doc_tokens
121        #[derive(#(#derive_tokens),*)]
122        #[sqlx_gen(kind = "composite")]
123        #type_attr
124        pub struct #struct_name {
125            #(#fields)*
126        }
127    };
128
129    (tokens, imports)
130}
131
/// Unit tests for `generate_composite`: each test renders a composite to
/// formatted source text and asserts on substrings of the output.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    /// Builds a composite type in the given schema.
    fn make_composite_in(schema: &str, name: &str, fields: Vec<ColumnInfo>) -> CompositeTypeInfo {
        CompositeTypeInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            fields,
        }
    }

    /// Builds a composite type in the `public` schema.
    fn make_composite(name: &str, fields: Vec<ColumnInfo>) -> CompositeTypeInfo {
        make_composite_in("public", name, fields)
    }

    /// Builds a field whose `data_type` and `udt_name` are both `udt_name`.
    fn make_field(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
            udt_schema: None,
            column_default: None,
        }
    }

    /// Renders `composite` for Postgres with the given derives and type
    /// overrides, returning the formatted code and the collected imports.
    fn render_with(
        composite: &CompositeTypeInfo,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let schema = SchemaInfo::default();
        let (tokens, imports) = generate_composite(
            composite,
            DatabaseKind::Postgres,
            &schema,
            derives,
            overrides,
            TimeCrate::Chrono,
        );
        (parse_and_format(&tokens).unwrap(), imports)
    }

    /// Renders `composite` with no extra derives and no overrides.
    // Named `render` rather than `gen`: `gen` is a reserved keyword in the
    // Rust 2024 edition.
    fn render(composite: &CompositeTypeInfo) -> String {
        render_with(composite, &[], &HashMap::new()).0
    }

    // --- basic structure ---

    #[test]
    fn test_simple_composite() {
        let c = make_composite(
            "address",
            vec![
                make_field("street", "text", false),
                make_field("city", "text", false),
            ],
        );
        let code = render(&c);
        assert!(code.contains("pub street: String"));
        assert!(code.contains("pub city: String"));
    }

    #[test]
    fn test_name_pascal_case() {
        let c = make_composite("geo_point", vec![make_field("x", "float8", false)]);
        let code = render(&c);
        assert!(code.contains("pub struct GeoPoint"));
    }

    #[test]
    fn test_doc_comment() {
        let c = make_composite("address", vec![make_field("x", "text", false)]);
        let code = render(&c);
        assert!(code.contains("Composite type: public.address"));
    }

    #[test]
    fn test_sqlx_type_name() {
        let c = make_composite("geo_point", vec![make_field("x", "float8", false)]);
        let code = render(&c);
        assert!(code.contains("sqlx(type_name = \"geo_point\")"));
    }

    #[test]
    fn test_does_not_emit_manual_pg_has_array_type_impl() {
        // Regression for E0119 — `#[derive(sqlx::Type)]` already provides this
        // impl when `type_name` is set, so emitting our own conflicted.
        let c = make_composite("address", vec![make_field("street", "text", false)]);
        let code = render(&c);
        assert!(
            !code.contains("PgHasArrayType"),
            "must not emit a manual PgHasArrayType impl, got:\n{}",
            code
        );
    }

    #[test]
    fn test_non_public_schema_type_name_is_unqualified() {
        // Regression: previously emitted "geo.point" which crashes sqlx 0.8 at runtime.
        let c = make_composite_in("geo", "point", vec![make_field("x", "float8", false)]);
        let code = render(&c);
        assert!(
            code.contains("sqlx(type_name = \"point\")"),
            "type_name must be unqualified for sqlx 0.8, got:\n{}",
            code
        );
        assert!(!code.contains("\"geo.point\""));
    }

    #[test]
    fn test_non_public_schema_documents_search_path() {
        // A composite outside the default schema gets a doc note telling the
        // user to put the schema on the connection's search_path.
        let c = make_composite_in("geo", "point", vec![make_field("x", "float8", false)]);
        let code = render(&c);
        assert!(
            code.contains("search_path"),
            "expected a search_path doc note, got:\n{}",
            code
        );
    }

    #[test]
    fn test_public_schema_not_qualified() {
        let c = make_composite("address", vec![make_field("x", "text", false)]);
        let code = render(&c);
        assert!(code.contains("sqlx(type_name = \"address\")"));
        // type_name should NOT be schema-qualified for public schema
        assert!(!code.contains("type_name = \"public.address\""));
    }

    // --- fields ---

    #[test]
    fn test_nullable_field() {
        let c = make_composite("address", vec![make_field("zip", "text", true)]);
        let code = render(&c);
        assert!(code.contains("Option<String>"));
    }

    #[test]
    fn test_non_nullable_field() {
        let c = make_composite("address", vec![make_field("city", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub city: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_keyword_field_prefixed() {
        let c = make_composite("item", vec![make_field("type", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub item_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    // --- rename ---

    #[test]
    fn test_camel_case_field_renamed() {
        let c = make_composite("address", vec![make_field("StreetName", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub street_name: String"));
        assert!(code.contains("sqlx(rename = \"StreetName\")"));
    }

    #[test]
    fn test_snake_case_field_no_rename() {
        let c = make_composite("address", vec![make_field("street_name", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub street_name: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    // --- types ---

    #[test]
    fn test_int4_field() {
        let c = make_composite("data", vec![make_field("count", "int4", false)]);
        let code = render(&c);
        assert!(code.contains("pub count: i32"));
    }

    #[test]
    fn test_uuid_field_import() {
        let c = make_composite("data", vec![make_field("id", "uuid", false)]);
        let (_, imports) = render_with(&c, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_text_field() {
        let c = make_composite("data", vec![make_field("label", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub label: String"));
    }

    // --- derives ---

    #[test]
    fn test_default_derives() {
        let c = make_composite("data", vec![make_field("x", "text", false)]);
        let code = render(&c);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("sqlx::Type") || code.contains("sqlx :: Type"));
    }

    #[test]
    fn test_extra_derive() {
        // `Hash` is not one of the default derives, so finding it in the output
        // proves the extra derive was actually emitted.
        let c = make_composite("data", vec![make_field("x", "text", false)]);
        let derives = vec!["Hash".to_string()];
        let (code, _) = render_with(&c, &derives, &HashMap::new());
        assert!(code.contains("Hash"));
    }

    // --- overrides ---

    #[test]
    fn test_type_override() {
        let c = make_composite("data", vec![make_field("payload", "jsonb", false)]);
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = render_with(&c, &[], &overrides);
        assert!(code.contains("pub payload: MyJson"));
    }
}