Skip to main content

sqlx_gen/codegen/
composite_gen.rs

1use std::collections::{BTreeSet, HashMap};
2
3use heck::ToSnakeCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::{DatabaseKind, TimeCrate};
8use crate::codegen::{imports_for_derives, is_rust_keyword, rust_type_name_for};
9use crate::introspect::{CompositeTypeInfo, SchemaInfo};
10use crate::typemap;
11
12pub fn generate_composite(
13    composite: &CompositeTypeInfo,
14    db_kind: DatabaseKind,
15    schema_info: &SchemaInfo,
16    extra_derives: &[String],
17    type_overrides: &HashMap<String, String>,
18    time_crate: TimeCrate,
19) -> (TokenStream, BTreeSet<String>) {
20    let mut imports = BTreeSet::new();
21    for imp in imports_for_derives(extra_derives) {
22        imports.insert(imp);
23    }
24    let rust_name = rust_type_name_for(schema_info, &composite.schema_name, &composite.name);
25    let struct_name = format_ident!("{}", rust_name);
26    let search_path_doc = if db_kind == DatabaseKind::Postgres
27        && !crate::codegen::is_default_schema(&composite.schema_name)
28    {
29        Some(format!(
30            "Lives in PostgreSQL schema `{schema}`. The sqlx connection must \
31             include `{schema}` in its search_path so PG resolves the \
32             unqualified `type_name = \"{name}\"` to this composite.",
33            schema = composite.schema_name,
34            name = composite.name,
35        ))
36    } else {
37        None
38    };
39
40    let doc = format!(
41        "Composite type: {}.{}",
42        composite.schema_name, composite.name
43    );
44
45    imports.insert("use serde::{Serialize, Deserialize};".to_string());
46    imports.insert("use sqlx_gen::SqlxGen;".to_string());
47    let mut derive_tokens = vec![
48        quote! { Debug },
49        quote! { Clone },
50        quote! { PartialEq },
51        quote! { Eq },
52        quote! { Serialize },
53        quote! { Deserialize },
54        quote! { sqlx::Type },
55        quote! { SqlxGen },
56    ];
57    for d in extra_derives {
58        let ident = format_ident!("{}", d);
59        derive_tokens.push(quote! { #ident });
60    }
61
62    // Always unqualified — sqlx 0.8's PgTypeInfo::with_name does not accept "schema.type"
63    // and emitting it triggers runtime decode errors. Non-public schemas require the
64    // connection's `search_path` to include the schema.
65    let pg_name = &composite.name;
66    let type_attr = quote! { #[sqlx(type_name = #pg_name)] };
67
68    let fields: Vec<TokenStream> = composite
69        .fields
70        .iter()
71        .map(|col| {
72            let rust_type =
73                typemap::map_column(col, db_kind, schema_info, type_overrides, time_crate);
74            if let Some(imp) = &rust_type.needs_import {
75                imports.insert(imp.clone());
76            }
77
78            let field_name_snake = col.name.to_snake_case();
79            let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
80                let prefixed = format!("{}_{}", composite.name.to_snake_case(), field_name_snake);
81                (prefixed, true)
82            } else {
83                let changed = field_name_snake != col.name;
84                (field_name_snake, changed)
85            };
86
87            let field_ident = format_ident!("{}", effective_name);
88            let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
89                let fallback = format_ident!("String");
90                quote! { #fallback }
91            });
92
93            let rename = if needs_rename {
94                let original = &col.name;
95                quote! { #[sqlx(rename = #original)] }
96            } else {
97                quote! {}
98            };
99
100            quote! {
101                #rename
102                pub #field_ident: #type_tokens,
103            }
104        })
105        .collect();
106
107    // `#[derive(sqlx::Type)]` with `#[sqlx(type_name = "x")]` auto-generates
108    // `impl PgHasArrayType` returning `_x`. Emitting a second impl triggers
109    // E0119 in the user's crate.
110    let _ = db_kind;
111
112    let search_path_doc_tokens = match &search_path_doc {
113        Some(m) => quote! { #[doc = #m] },
114        None => quote! {},
115    };
116    let tokens = quote! {
117        #[doc = #doc]
118        #search_path_doc_tokens
119        #[derive(#(#derive_tokens),*)]
120        #[sqlx_gen(kind = "composite")]
121        #type_attr
122        pub struct #struct_name {
123            #(#fields)*
124        }
125    };
126
127    (tokens, imports)
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    /// Builds a composite type in the `public` schema.
    fn make_composite(name: &str, fields: Vec<ColumnInfo>) -> CompositeTypeInfo {
        CompositeTypeInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            fields,
        }
    }

    /// Builds a non-key attribute whose `data_type` and `udt_name` are both `udt_name`.
    fn make_field(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
            udt_schema: None,
            column_default: None,
        }
    }

    /// Renders `composite` for PostgreSQL with no extra derives or overrides.
    /// (Named `render` rather than `gen`, which is reserved in Rust 2024.)
    fn render(composite: &CompositeTypeInfo) -> String {
        render_with(composite, &[], &HashMap::new()).0
    }

    /// Renders `composite` for PostgreSQL and returns the formatted code
    /// together with the collected imports.
    fn render_with(
        composite: &CompositeTypeInfo,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let schema = SchemaInfo::default();
        let (tokens, imports) = generate_composite(
            composite,
            DatabaseKind::Postgres,
            &schema,
            derives,
            overrides,
            TimeCrate::Chrono,
        );
        (parse_and_format(&tokens).unwrap(), imports)
    }

    // --- basic structure ---

    #[test]
    fn test_simple_composite() {
        let c = make_composite(
            "address",
            vec![
                make_field("street", "text", false),
                make_field("city", "text", false),
            ],
        );
        let code = render(&c);
        assert!(code.contains("pub street: String"));
        assert!(code.contains("pub city: String"));
    }

    #[test]
    fn test_name_pascal_case() {
        let c = make_composite("geo_point", vec![make_field("x", "float8", false)]);
        let code = render(&c);
        assert!(code.contains("pub struct GeoPoint"));
    }

    #[test]
    fn test_doc_comment() {
        let c = make_composite("address", vec![make_field("x", "text", false)]);
        let code = render(&c);
        assert!(code.contains("Composite type: public.address"));
    }

    #[test]
    fn test_sqlx_type_name() {
        let c = make_composite("geo_point", vec![make_field("x", "float8", false)]);
        let code = render(&c);
        assert!(code.contains("sqlx(type_name = \"geo_point\")"));
    }

    #[test]
    fn test_does_not_emit_manual_pg_has_array_type_impl() {
        // Regression for E0119 — `#[derive(sqlx::Type)]` already provides this
        // impl when `type_name` is set, so emitting our own conflicted.
        let c = make_composite("address", vec![make_field("street", "text", false)]);
        let code = render(&c);
        assert!(
            !code.contains("PgHasArrayType"),
            "must not emit a manual PgHasArrayType impl, got:\n{}",
            code
        );
    }

    #[test]
    fn test_non_public_schema_type_name_is_unqualified() {
        // Regression: previously emitted "geo.point" which crashes sqlx 0.8 at runtime.
        let c = CompositeTypeInfo {
            schema_name: "geo".to_string(),
            name: "point".to_string(),
            fields: vec![make_field("x", "float8", false)],
        };
        let code = render(&c);
        assert!(
            code.contains("sqlx(type_name = \"point\")"),
            "type_name must be unqualified for sqlx 0.8, got:\n{}",
            code
        );
        assert!(!code.contains("\"geo.point\""));
    }

    #[test]
    fn test_non_public_schema_documents_search_path() {
        let c = CompositeTypeInfo {
            schema_name: "geo".to_string(),
            name: "point".to_string(),
            fields: vec![make_field("x", "float8", false)],
        };
        let code = render(&c);
        assert!(
            code.contains("Lives in PostgreSQL schema `geo`") && code.contains("search_path"),
            "non-public schema must carry a search_path note, got:\n{}",
            code
        );
    }

    #[test]
    fn test_public_schema_not_qualified() {
        let c = make_composite("address", vec![make_field("x", "text", false)]);
        let code = render(&c);
        assert!(code.contains("sqlx(type_name = \"address\")"));
        // type_name should NOT be schema-qualified for public schema
        assert!(!code.contains("type_name = \"public.address\""));
    }

    #[test]
    fn test_public_schema_has_no_search_path_note() {
        let c = make_composite("address", vec![make_field("x", "text", false)]);
        let code = render(&c);
        assert!(!code.contains("search_path"));
    }

    // --- fields ---

    #[test]
    fn test_nullable_field() {
        let c = make_composite("address", vec![make_field("zip", "text", true)]);
        let code = render(&c);
        assert!(code.contains("Option<String>"));
    }

    #[test]
    fn test_non_nullable_field() {
        let c = make_composite("address", vec![make_field("city", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub city: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_keyword_field_prefixed() {
        let c = make_composite("item", vec![make_field("type", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub item_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    // --- rename ---

    #[test]
    fn test_camel_case_field_renamed() {
        let c = make_composite("address", vec![make_field("StreetName", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub street_name: String"));
        assert!(code.contains("sqlx(rename = \"StreetName\")"));
    }

    #[test]
    fn test_snake_case_field_no_rename() {
        let c = make_composite("address", vec![make_field("street_name", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub street_name: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    // --- types ---

    #[test]
    fn test_int4_field() {
        let c = make_composite("data", vec![make_field("count", "int4", false)]);
        let code = render(&c);
        assert!(code.contains("pub count: i32"));
    }

    #[test]
    fn test_uuid_field_import() {
        let c = make_composite("data", vec![make_field("id", "uuid", false)]);
        let (_, imports) = render_with(&c, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_text_field() {
        let c = make_composite("data", vec![make_field("label", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub label: String"));
    }

    // --- derives ---

    #[test]
    fn test_default_derives() {
        let c = make_composite("data", vec![make_field("x", "text", false)]);
        let code = render(&c);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("sqlx::Type") || code.contains("sqlx :: Type"));
    }

    #[test]
    fn test_extra_derive() {
        // Use a derive that is not in the default set; `Serialize` is always
        // emitted, so asserting on it could never fail.
        let c = make_composite("data", vec![make_field("x", "text", false)]);
        let derives = vec!["Hash".to_string()];
        let (code, _) = render_with(&c, &derives, &HashMap::new());
        assert!(code.contains("Hash"), "extra derive missing, got:\n{}", code);
    }

    // --- overrides ---

    #[test]
    fn test_type_override() {
        let c = make_composite("data", vec![make_field("payload", "jsonb", false)]);
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = render_with(&c, &[], &overrides);
        assert!(code.contains("pub payload: MyJson"));
    }
}