Skip to main content

sqlx_gen/codegen/
composite_gen.rs

1use std::collections::{BTreeSet, HashMap};
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::{imports_for_derives, is_rust_keyword};
9use crate::introspect::{CompositeTypeInfo, SchemaInfo};
10use crate::typemap;
11
12pub fn generate_composite(
13    composite: &CompositeTypeInfo,
14    db_kind: DatabaseKind,
15    schema_info: &SchemaInfo,
16    extra_derives: &[String],
17    type_overrides: &HashMap<String, String>,
18) -> (TokenStream, BTreeSet<String>) {
19    let mut imports = BTreeSet::new();
20    for imp in imports_for_derives(extra_derives) {
21        imports.insert(imp);
22    }
23    let struct_name = format_ident!("{}", composite.name.to_upper_camel_case());
24
25    let doc = format!(
26        "Composite type: {}.{}",
27        composite.schema_name, composite.name
28    );
29
30    imports.insert("use serde::{Serialize, Deserialize};".to_string());
31    imports.insert("use sqlx_gen::SqlxGen;".to_string());
32    let mut derive_tokens = vec![
33        quote! { Debug },
34        quote! { Clone },
35        quote! { PartialEq },
36        quote! { Eq },
37        quote! { Serialize },
38        quote! { Deserialize },
39        quote! { sqlx::Type },
40        quote! { SqlxGen },
41    ];
42    for d in extra_derives {
43        let ident = format_ident!("{}", d);
44        derive_tokens.push(quote! { #ident });
45    }
46
47    let pg_name = &composite.name;
48    let type_attr = quote! { #[sqlx(type_name = #pg_name)] };
49
50    let fields: Vec<TokenStream> = composite
51        .fields
52        .iter()
53        .map(|col| {
54            let rust_type = typemap::map_column(col, db_kind, schema_info, type_overrides);
55            if let Some(imp) = &rust_type.needs_import {
56                imports.insert(imp.clone());
57            }
58
59            let field_name_snake = col.name.to_snake_case();
60            let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
61                let prefixed = format!(
62                    "{}_{}",
63                    composite.name.to_snake_case(),
64                    field_name_snake
65                );
66                (prefixed, true)
67            } else {
68                let changed = field_name_snake != col.name;
69                (field_name_snake, changed)
70            };
71
72            let field_ident = format_ident!("{}", effective_name);
73            let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
74                let fallback = format_ident!("String");
75                quote! { #fallback }
76            });
77
78            let rename = if needs_rename {
79                let original = &col.name;
80                quote! { #[sqlx(rename = #original)] }
81            } else {
82                quote! {}
83            };
84
85            quote! {
86                #rename
87                pub #field_ident: #type_tokens,
88            }
89        })
90        .collect();
91
92    let tokens = quote! {
93        #[doc = #doc]
94        #[derive(#(#derive_tokens),*)]
95        #[sqlx_gen(kind = "composite")]
96        #type_attr
97        pub struct #struct_name {
98            #(#fields)*
99        }
100    };
101
102    (tokens, imports)
103}
104
// Unit tests for `generate_composite`: each test builds a `CompositeTypeInfo`, generates
// Postgres code against an empty `SchemaInfo`, and inspects the formatted output.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    /// Builds a composite type in the `public` schema.
    fn make_composite(name: &str, fields: Vec<ColumnInfo>) -> CompositeTypeInfo {
        CompositeTypeInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            fields,
        }
    }

    /// Builds a non-key field whose `data_type` and `udt_name` are both `udt_name`.
    fn make_field(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
            column_default: None,
        }
    }

    /// Generates formatted code with no extra derives and no type overrides.
    fn gen(composite: &CompositeTypeInfo) -> String {
        let schema = SchemaInfo::default();
        let (tokens, _) = generate_composite(composite, DatabaseKind::Postgres, &schema, &[], &HashMap::new());
        parse_and_format(&tokens)
    }

    /// Generates formatted code plus the import set, with custom derives and overrides.
    fn gen_with(
        composite: &CompositeTypeInfo,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let schema = SchemaInfo::default();
        let (tokens, imports) = generate_composite(composite, DatabaseKind::Postgres, &schema, derives, overrides);
        (parse_and_format(&tokens), imports)
    }

    // --- basic structure ---

    #[test]
    fn test_simple_composite() {
        let c = make_composite("address", vec![
            make_field("street", "text", false),
            make_field("city", "text", false),
        ]);
        let code = gen(&c);
        assert!(code.contains("pub street: String"));
        assert!(code.contains("pub city: String"));
    }

    #[test]
    fn test_name_pascal_case() {
        let c = make_composite("geo_point", vec![make_field("x", "float8", false)]);
        let code = gen(&c);
        assert!(code.contains("pub struct GeoPoint"));
    }

    #[test]
    fn test_doc_comment() {
        let c = make_composite("address", vec![make_field("x", "text", false)]);
        let code = gen(&c);
        assert!(code.contains("Composite type: public.address"));
    }

    #[test]
    fn test_sqlx_type_name() {
        let c = make_composite("geo_point", vec![make_field("x", "float8", false)]);
        let code = gen(&c);
        assert!(code.contains("sqlx(type_name = \"geo_point\")"));
    }

    // The sqlx `type_name` uses the bare composite name even outside `public`;
    // the schema is only recorded in the doc comment.
    #[test]
    fn test_non_public_schema_type_name_unqualified() {
        let c = CompositeTypeInfo {
            schema_name: "geo".to_string(),
            name: "point".to_string(),
            fields: vec![make_field("x", "float8", false)],
        };
        let schema = SchemaInfo::default();
        let (tokens, _) = generate_composite(&c, DatabaseKind::Postgres, &schema, &[], &HashMap::new());
        let code = parse_and_format(&tokens);
        assert!(code.contains("sqlx(type_name = \"point\")"));
        assert!(code.contains("Composite type: geo.point"));
    }

    #[test]
    fn test_public_schema_not_qualified() {
        let c = make_composite("address", vec![make_field("x", "text", false)]);
        let code = gen(&c);
        assert!(code.contains("sqlx(type_name = \"address\")"));
        // type_name should NOT be schema-qualified for public schema
        assert!(!code.contains("type_name = \"public.address\""));
    }

    // --- fields ---

    #[test]
    fn test_nullable_field() {
        let c = make_composite("address", vec![make_field("zip", "text", true)]);
        let code = gen(&c);
        assert!(code.contains("Option<String>"));
    }

    #[test]
    fn test_non_nullable_field() {
        let c = make_composite("address", vec![make_field("city", "text", false)]);
        let code = gen(&c);
        assert!(code.contains("pub city: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_keyword_field_prefixed() {
        let c = make_composite("item", vec![make_field("type", "text", false)]);
        let code = gen(&c);
        assert!(code.contains("pub item_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    // --- rename ---

    #[test]
    fn test_camel_case_field_renamed() {
        let c = make_composite("address", vec![make_field("StreetName", "text", false)]);
        let code = gen(&c);
        assert!(code.contains("pub street_name: String"));
        assert!(code.contains("sqlx(rename = \"StreetName\")"));
    }

    #[test]
    fn test_snake_case_field_no_rename() {
        let c = make_composite("address", vec![make_field("street_name", "text", false)]);
        let code = gen(&c);
        assert!(code.contains("pub street_name: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    // --- types ---

    #[test]
    fn test_int4_field() {
        let c = make_composite("data", vec![make_field("count", "int4", false)]);
        let code = gen(&c);
        assert!(code.contains("pub count: i32"));
    }

    #[test]
    fn test_uuid_field_import() {
        let c = make_composite("data", vec![make_field("id", "uuid", false)]);
        let (_, imports) = gen_with(&c, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_text_field() {
        let c = make_composite("data", vec![make_field("label", "text", false)]);
        let code = gen(&c);
        assert!(code.contains("pub label: String"));
    }

    // --- derives ---

    #[test]
    fn test_default_derives() {
        let c = make_composite("data", vec![make_field("x", "text", false)]);
        let code = gen(&c);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("sqlx::Type") || code.contains("sqlx :: Type"));
    }

    // Uses a derive that is NOT among the defaults, so the assertion can only pass
    // if extra derives are actually emitted.
    #[test]
    fn test_extra_derive() {
        let c = make_composite("data", vec![make_field("x", "text", false)]);
        let derives = vec!["Hash".to_string()];
        let (code, _) = gen_with(&c, &derives, &HashMap::new());
        assert!(code.contains("Hash"));
    }

    // --- overrides ---

    #[test]
    fn test_type_override() {
        let c = make_composite("data", vec![make_field("payload", "jsonb", false)]);
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = gen_with(&c, &[], &overrides);
        assert!(code.contains("pub payload: MyJson"));
    }
}