Skip to main content

sqlx_gen/codegen/
composite_gen.rs

1use std::collections::{BTreeSet, HashMap};
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::{DatabaseKind, TimeCrate};
8use crate::codegen::{imports_for_derives, is_rust_keyword};
9use crate::introspect::{CompositeTypeInfo, SchemaInfo};
10use crate::typemap;
11
12pub fn generate_composite(
13    composite: &CompositeTypeInfo,
14    db_kind: DatabaseKind,
15    schema_info: &SchemaInfo,
16    extra_derives: &[String],
17    type_overrides: &HashMap<String, String>,
18    time_crate: TimeCrate,
19) -> (TokenStream, BTreeSet<String>) {
20    let mut imports = BTreeSet::new();
21    for imp in imports_for_derives(extra_derives) {
22        imports.insert(imp);
23    }
24    let struct_name = format_ident!("{}", composite.name.to_upper_camel_case());
25
26    let doc = format!(
27        "Composite type: {}.{}",
28        composite.schema_name, composite.name
29    );
30
31    imports.insert("use serde::{Serialize, Deserialize};".to_string());
32    imports.insert("use sqlx_gen::SqlxGen;".to_string());
33    let mut derive_tokens = vec![
34        quote! { Debug },
35        quote! { Clone },
36        quote! { PartialEq },
37        quote! { Eq },
38        quote! { Serialize },
39        quote! { Deserialize },
40        quote! { sqlx::Type },
41        quote! { SqlxGen },
42    ];
43    for d in extra_derives {
44        let ident = format_ident!("{}", d);
45        derive_tokens.push(quote! { #ident });
46    }
47
48    let pg_name = &composite.name;
49    let type_attr = quote! { #[sqlx(type_name = #pg_name)] };
50
51    let fields: Vec<TokenStream> = composite
52        .fields
53        .iter()
54        .map(|col| {
55            let rust_type = typemap::map_column(col, db_kind, schema_info, type_overrides, time_crate);
56            if let Some(imp) = &rust_type.needs_import {
57                imports.insert(imp.clone());
58            }
59
60            let field_name_snake = col.name.to_snake_case();
61            let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
62                let prefixed = format!(
63                    "{}_{}",
64                    composite.name.to_snake_case(),
65                    field_name_snake
66                );
67                (prefixed, true)
68            } else {
69                let changed = field_name_snake != col.name;
70                (field_name_snake, changed)
71            };
72
73            let field_ident = format_ident!("{}", effective_name);
74            let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
75                let fallback = format_ident!("String");
76                quote! { #fallback }
77            });
78
79            let rename = if needs_rename {
80                let original = &col.name;
81                quote! { #[sqlx(rename = #original)] }
82            } else {
83                quote! {}
84            };
85
86            quote! {
87                #rename
88                pub #field_ident: #type_tokens,
89            }
90        })
91        .collect();
92
93    let tokens = quote! {
94        #[doc = #doc]
95        #[derive(#(#derive_tokens),*)]
96        #[sqlx_gen(kind = "composite")]
97        #type_attr
98        pub struct #struct_name {
99            #(#fields)*
100        }
101    };
102
103    (tokens, imports)
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    /// Builds a composite type in the `public` schema.
    fn make_composite(name: &str, fields: Vec<ColumnInfo>) -> CompositeTypeInfo {
        make_composite_in("public", name, fields)
    }

    /// Builds a composite type in an arbitrary schema.
    fn make_composite_in(schema: &str, name: &str, fields: Vec<ColumnInfo>) -> CompositeTypeInfo {
        CompositeTypeInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            fields,
        }
    }

    /// Builds a non-key field in `public` whose `data_type` and `udt_name` are both `udt_name`.
    fn make_field(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
            column_default: None,
        }
    }

    /// Generates and pretty-prints a composite with no extra derives or overrides.
    ///
    /// Deliberately not named `gen`, which is a reserved keyword in Rust 2024.
    fn render(composite: &CompositeTypeInfo) -> String {
        render_with(composite, &[], &HashMap::new()).0
    }

    /// Generates a Postgres composite (chrono time types) and returns the
    /// pretty-printed code together with the collected imports.
    fn render_with(
        composite: &CompositeTypeInfo,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let schema = SchemaInfo::default();
        let (tokens, imports) = generate_composite(
            composite,
            DatabaseKind::Postgres,
            &schema,
            derives,
            overrides,
            TimeCrate::Chrono,
        );
        (parse_and_format(&tokens), imports)
    }

    // --- basic structure ---

    #[test]
    fn test_simple_composite() {
        let c = make_composite("address", vec![
            make_field("street", "text", false),
            make_field("city", "text", false),
        ]);
        let code = render(&c);
        assert!(code.contains("pub street: String"));
        assert!(code.contains("pub city: String"));
    }

    #[test]
    fn test_name_pascal_case() {
        let c = make_composite("geo_point", vec![make_field("x", "float8", false)]);
        let code = render(&c);
        assert!(code.contains("pub struct GeoPoint"));
    }

    #[test]
    fn test_doc_comment() {
        let c = make_composite("address", vec![make_field("x", "text", false)]);
        let code = render(&c);
        assert!(code.contains("Composite type: public.address"));
    }

    #[test]
    fn test_sqlx_type_name() {
        let c = make_composite("geo_point", vec![make_field("x", "float8", false)]);
        let code = render(&c);
        assert!(code.contains("sqlx(type_name = \"geo_point\")"));
    }

    #[test]
    fn test_non_public_schema_type_name_unqualified() {
        // The schema only shows up in the doc comment; `type_name` is the bare name.
        let c = make_composite_in("geo", "point", vec![make_field("x", "float8", false)]);
        let code = render(&c);
        assert!(code.contains("sqlx(type_name = \"point\")"));
        assert!(!code.contains("type_name = \"geo.point\""));
        assert!(code.contains("Composite type: geo.point"));
    }

    #[test]
    fn test_public_schema_not_qualified() {
        let c = make_composite("address", vec![make_field("x", "text", false)]);
        let code = render(&c);
        assert!(code.contains("sqlx(type_name = \"address\")"));
        // type_name should NOT be schema-qualified for public schema
        assert!(!code.contains("type_name = \"public.address\""));
    }

    // --- fields ---

    #[test]
    fn test_nullable_field() {
        let c = make_composite("address", vec![make_field("zip", "text", true)]);
        let code = render(&c);
        assert!(code.contains("Option<String>"));
    }

    #[test]
    fn test_non_nullable_field() {
        let c = make_composite("address", vec![make_field("city", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub city: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_keyword_field_prefixed() {
        let c = make_composite("item", vec![make_field("type", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub item_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    // --- rename ---

    #[test]
    fn test_camel_case_field_renamed() {
        let c = make_composite("address", vec![make_field("StreetName", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub street_name: String"));
        assert!(code.contains("sqlx(rename = \"StreetName\")"));
    }

    #[test]
    fn test_snake_case_field_no_rename() {
        let c = make_composite("address", vec![make_field("street_name", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub street_name: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    // --- types ---

    #[test]
    fn test_int4_field() {
        let c = make_composite("data", vec![make_field("count", "int4", false)]);
        let code = render(&c);
        assert!(code.contains("pub count: i32"));
    }

    #[test]
    fn test_uuid_field_import() {
        let c = make_composite("data", vec![make_field("id", "uuid", false)]);
        let (_, imports) = render_with(&c, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_text_field() {
        let c = make_composite("data", vec![make_field("label", "text", false)]);
        let code = render(&c);
        assert!(code.contains("pub label: String"));
    }

    // --- derives ---

    #[test]
    fn test_default_derives() {
        let c = make_composite("data", vec![make_field("x", "text", false)]);
        let code = render(&c);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("sqlx::Type") || code.contains("sqlx :: Type"));
    }

    #[test]
    fn test_extra_derive() {
        // `Hash` is not a default derive, so this only passes if the extra
        // derive is actually emitted (`Serialize` would pass regardless).
        let c = make_composite("data", vec![make_field("x", "text", false)]);
        let derives = vec!["Hash".to_string()];
        let (code, _) = render_with(&c, &derives, &HashMap::new());
        assert!(code.contains("Hash"));
    }

    // --- overrides ---

    #[test]
    fn test_type_override() {
        let c = make_composite("data", vec![make_field("payload", "jsonb", false)]);
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = render_with(&c, &[], &overrides);
        assert!(code.contains("pub payload: MyJson"));
    }
}