Skip to main content

sqlx_gen/codegen/
composite_gen.rs

1use std::collections::{BTreeSet, HashMap};
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::{imports_for_derives, is_rust_keyword};
9use crate::introspect::{CompositeTypeInfo, SchemaInfo};
10use crate::typemap;
11
12pub fn generate_composite(
13    composite: &CompositeTypeInfo,
14    db_kind: DatabaseKind,
15    schema_info: &SchemaInfo,
16    extra_derives: &[String],
17    type_overrides: &HashMap<String, String>,
18) -> (TokenStream, BTreeSet<String>) {
19    let mut imports = BTreeSet::new();
20    for imp in imports_for_derives(extra_derives) {
21        imports.insert(imp);
22    }
23    let struct_name = format_ident!("{}", composite.name.to_upper_camel_case());
24
25    let doc = format!(
26        "Composite type: {}.{}",
27        composite.schema_name, composite.name
28    );
29
30    imports.insert("use serde::{Serialize, Deserialize};".to_string());
31    imports.insert("use sqlx_gen::SqlxGen;".to_string());
32    let mut derive_tokens = vec![
33        quote! { Debug },
34        quote! { Clone },
35        quote! { PartialEq },
36        quote! { Eq },
37        quote! { Serialize },
38        quote! { Deserialize },
39        quote! { sqlx::Type },
40        quote! { SqlxGen },
41    ];
42    for d in extra_derives {
43        let ident = format_ident!("{}", d);
44        derive_tokens.push(quote! { #ident });
45    }
46
47    // Schema-qualify the type name for non-public schemas so sqlx can find the type
48    let pg_name = if composite.schema_name != "public" {
49        format!("{}.{}", composite.schema_name, composite.name)
50    } else {
51        composite.name.clone()
52    };
53    let type_attr = quote! { #[sqlx(type_name = #pg_name)] };
54
55    let fields: Vec<TokenStream> = composite
56        .fields
57        .iter()
58        .map(|col| {
59            let rust_type = typemap::map_column(col, db_kind, schema_info, type_overrides);
60            if let Some(imp) = &rust_type.needs_import {
61                imports.insert(imp.clone());
62            }
63
64            let field_name_snake = col.name.to_snake_case();
65            let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
66                let prefixed = format!(
67                    "{}_{}",
68                    composite.name.to_snake_case(),
69                    field_name_snake
70                );
71                (prefixed, true)
72            } else {
73                let changed = field_name_snake != col.name;
74                (field_name_snake, changed)
75            };
76
77            let field_ident = format_ident!("{}", effective_name);
78            let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
79                let fallback = format_ident!("String");
80                quote! { #fallback }
81            });
82
83            let rename = if needs_rename {
84                let original = &col.name;
85                quote! { #[sqlx(rename = #original)] }
86            } else {
87                quote! {}
88            };
89
90            quote! {
91                #rename
92                pub #field_ident: #type_tokens,
93            }
94        })
95        .collect();
96
97    let tokens = quote! {
98        #[doc = #doc]
99        #[derive(#(#derive_tokens),*)]
100        #[sqlx_gen(kind = "composite")]
101        #type_attr
102        pub struct #struct_name {
103            #(#fields)*
104        }
105    };
106
107    (tokens, imports)
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    /// Builds a composite type named `name` in schema `schema`.
    fn make_composite_in(schema: &str, name: &str, fields: Vec<ColumnInfo>) -> CompositeTypeInfo {
        CompositeTypeInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            fields,
        }
    }

    /// Builds a composite type named `name` in the `public` schema.
    fn make_composite(name: &str, fields: Vec<ColumnInfo>) -> CompositeTypeInfo {
        make_composite_in("public", name, fields)
    }

    /// Builds a non-key composite field whose data type and UDT name are both `udt_name`.
    fn make_field(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
            column_default: None,
        }
    }

    /// Generates Postgres code for `composite` with the given extra derives and
    /// type overrides. Returns the formatted source and the collected imports.
    fn gen_code_with(
        composite: &CompositeTypeInfo,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let schema = SchemaInfo::default();
        let (tokens, imports) =
            generate_composite(composite, DatabaseKind::Postgres, &schema, derives, overrides);
        (parse_and_format(&tokens), imports)
    }

    /// Generates formatted Postgres code for `composite` with no extra derives or overrides.
    fn gen_code(composite: &CompositeTypeInfo) -> String {
        gen_code_with(composite, &[], &HashMap::new()).0
    }

    // --- basic structure ---

    #[test]
    fn test_simple_composite() {
        let c = make_composite("address", vec![
            make_field("street", "text", false),
            make_field("city", "text", false),
        ]);
        let code = gen_code(&c);
        assert!(code.contains("pub street: String"));
        assert!(code.contains("pub city: String"));
    }

    #[test]
    fn test_name_pascal_case() {
        let c = make_composite("geo_point", vec![make_field("x", "float8", false)]);
        let code = gen_code(&c);
        assert!(code.contains("pub struct GeoPoint"));
    }

    #[test]
    fn test_doc_comment() {
        let c = make_composite("address", vec![make_field("x", "text", false)]);
        let code = gen_code(&c);
        assert!(code.contains("Composite type: public.address"));
    }

    #[test]
    fn test_sqlx_type_name() {
        let c = make_composite("geo_point", vec![make_field("x", "float8", false)]);
        let code = gen_code(&c);
        assert!(code.contains("sqlx(type_name = \"geo_point\")"));
    }

    #[test]
    fn test_sqlx_gen_kind_attribute() {
        let c = make_composite("address", vec![make_field("x", "text", false)]);
        let code = gen_code(&c);
        assert!(code.contains("sqlx_gen(kind = \"composite\")"));
    }

    #[test]
    fn test_non_public_schema_qualified_type_name() {
        let c = make_composite_in("geo", "point", vec![make_field("x", "float8", false)]);
        let code = gen_code(&c);
        assert!(code.contains("sqlx(type_name = \"geo.point\")"));
        assert!(code.contains("Composite type: geo.point"));
    }

    #[test]
    fn test_public_schema_not_qualified() {
        let c = make_composite("address", vec![make_field("x", "text", false)]);
        let code = gen_code(&c);
        assert!(code.contains("sqlx(type_name = \"address\")"));
        // type_name should NOT be schema-qualified for public schema
        assert!(!code.contains("type_name = \"public.address\""));
    }

    // --- fields ---

    #[test]
    fn test_nullable_field() {
        let c = make_composite("address", vec![make_field("zip", "text", true)]);
        let code = gen_code(&c);
        assert!(code.contains("Option<String>"));
    }

    #[test]
    fn test_non_nullable_field() {
        let c = make_composite("address", vec![make_field("city", "text", false)]);
        let code = gen_code(&c);
        assert!(code.contains("pub city: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_keyword_field_prefixed() {
        let c = make_composite("item", vec![make_field("type", "text", false)]);
        let code = gen_code(&c);
        assert!(code.contains("pub item_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    // --- rename ---

    #[test]
    fn test_camel_case_field_renamed() {
        let c = make_composite("address", vec![make_field("StreetName", "text", false)]);
        let code = gen_code(&c);
        assert!(code.contains("pub street_name: String"));
        assert!(code.contains("sqlx(rename = \"StreetName\")"));
    }

    #[test]
    fn test_snake_case_field_no_rename() {
        let c = make_composite("address", vec![make_field("street_name", "text", false)]);
        let code = gen_code(&c);
        assert!(code.contains("pub street_name: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    // --- types ---

    #[test]
    fn test_int4_field() {
        let c = make_composite("data", vec![make_field("count", "int4", false)]);
        let code = gen_code(&c);
        assert!(code.contains("pub count: i32"));
    }

    #[test]
    fn test_uuid_field_import() {
        let c = make_composite("data", vec![make_field("id", "uuid", false)]);
        let (_, imports) = gen_code_with(&c, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_text_field() {
        let c = make_composite("data", vec![make_field("label", "text", false)]);
        let code = gen_code(&c);
        assert!(code.contains("pub label: String"));
    }

    // --- derives ---

    #[test]
    fn test_default_derives() {
        let c = make_composite("data", vec![make_field("x", "text", false)]);
        let code = gen_code(&c);
        for derive in ["Debug", "Clone", "PartialEq", "Serialize", "Deserialize", "SqlxGen"] {
            assert!(code.contains(derive), "missing default derive {derive}");
        }
        assert!(code.contains("sqlx::Type") || code.contains("sqlx :: Type"));
    }

    #[test]
    fn test_default_imports() {
        let c = make_composite("data", vec![make_field("x", "text", false)]);
        let (_, imports) = gen_code_with(&c, &[], &HashMap::new());
        assert!(imports.contains("use serde::{Serialize, Deserialize};"));
        assert!(imports.contains("use sqlx_gen::SqlxGen;"));
    }

    #[test]
    fn test_extra_derive() {
        // Use a derive that is NOT in the default set, so the assertion can
        // only pass if extra derives are actually emitted.
        let c = make_composite("data", vec![make_field("x", "text", false)]);
        let derives = vec!["Hash".to_string()];
        let (code, _) = gen_code_with(&c, &derives, &HashMap::new());
        assert!(code.contains("Hash"));
    }

    // --- overrides ---

    #[test]
    fn test_type_override() {
        let c = make_composite("data", vec![make_field("payload", "jsonb", false)]);
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = gen_code_with(&c, &[], &overrides);
        assert!(code.contains("pub payload: MyJson"));
    }
}