Skip to main content

sqlx_gen/codegen/
composite_gen.rs

1use std::collections::{BTreeSet, HashMap};
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::{imports_for_derives, is_rust_keyword};
9use crate::introspect::{CompositeTypeInfo, SchemaInfo};
10use crate::typemap;
11
12pub fn generate_composite(
13    composite: &CompositeTypeInfo,
14    db_kind: DatabaseKind,
15    schema_info: &SchemaInfo,
16    extra_derives: &[String],
17    type_overrides: &HashMap<String, String>,
18) -> (TokenStream, BTreeSet<String>) {
19    let mut imports = BTreeSet::new();
20    for imp in imports_for_derives(extra_derives) {
21        imports.insert(imp);
22    }
23    let struct_name = format_ident!("{}", composite.name.to_upper_camel_case());
24
25    let doc = format!(
26        "Composite type: {}.{}",
27        composite.schema_name, composite.name
28    );
29
30    let mut derive_tokens = vec![
31        quote! { Debug },
32        quote! { Clone },
33        quote! { sqlx::Type },
34    ];
35    for d in extra_derives {
36        let ident = format_ident!("{}", d);
37        derive_tokens.push(quote! { #ident });
38    }
39
40    let pg_name = &composite.name;
41    let type_attr = quote! { #[sqlx(type_name = #pg_name)] };
42
43    let fields: Vec<TokenStream> = composite
44        .fields
45        .iter()
46        .map(|col| {
47            let rust_type = typemap::map_column(col, db_kind, schema_info, type_overrides);
48            if let Some(imp) = &rust_type.needs_import {
49                imports.insert(imp.clone());
50            }
51
52            let field_name_snake = col.name.to_snake_case();
53            let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
54                let prefixed = format!(
55                    "{}_{}",
56                    composite.name.to_snake_case(),
57                    field_name_snake
58                );
59                (prefixed, true)
60            } else {
61                let changed = field_name_snake != col.name;
62                (field_name_snake, changed)
63            };
64
65            let field_ident = format_ident!("{}", effective_name);
66            let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
67                let fallback = format_ident!("String");
68                quote! { #fallback }
69            });
70
71            let rename = if needs_rename {
72                let original = &col.name;
73                quote! { #[sqlx(rename = #original)] }
74            } else {
75                quote! {}
76            };
77
78            quote! {
79                #rename
80                pub #field_ident: #type_tokens,
81            }
82        })
83        .collect();
84
85    let tokens = quote! {
86        #[doc = #doc]
87        #[derive(#(#derive_tokens),*)]
88        #type_attr
89        pub struct #struct_name {
90            #(#fields)*
91        }
92    };
93
94    (tokens, imports)
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    /// Build a composite type in the `public` schema with the given fields.
    fn make_composite(name: &str, fields: Vec<ColumnInfo>) -> CompositeTypeInfo {
        CompositeTypeInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            fields,
        }
    }

    /// Build a composite attribute. `udt_name` is used for both `data_type` and `udt_name`.
    fn make_field(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            ordinal_position: 0,
            schema_name: "public".to_string(),
        }
    }

    /// Generate and format code for `composite` with no extra derives or overrides.
    ///
    /// Deliberately not named `gen`: that is a reserved keyword in the Rust 2024 edition.
    fn gen_code(composite: &CompositeTypeInfo) -> String {
        gen_with(composite, &[], &HashMap::new()).0
    }

    /// Generate and format code for `composite`, also returning the collected imports.
    fn gen_with(
        composite: &CompositeTypeInfo,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let schema = SchemaInfo::default();
        let (tokens, imports) =
            generate_composite(composite, DatabaseKind::Postgres, &schema, derives, overrides);
        (parse_and_format(&tokens), imports)
    }

    // --- basic structure ---

    #[test]
    fn test_simple_composite() {
        let c = make_composite("address", vec![
            make_field("street", "text", false),
            make_field("city", "text", false),
        ]);
        let code = gen_code(&c);
        assert!(code.contains("pub street: String"));
        assert!(code.contains("pub city: String"));
    }

    #[test]
    fn test_name_pascal_case() {
        let c = make_composite("geo_point", vec![make_field("x", "float8", false)]);
        let code = gen_code(&c);
        assert!(code.contains("pub struct GeoPoint"));
    }

    #[test]
    fn test_doc_comment() {
        let c = make_composite("address", vec![make_field("x", "text", false)]);
        let code = gen_code(&c);
        assert!(code.contains("Composite type: public.address"));
    }

    #[test]
    fn test_sqlx_type_name() {
        let c = make_composite("geo_point", vec![make_field("x", "float8", false)]);
        let code = gen_code(&c);
        assert!(code.contains("sqlx(type_name = \"geo_point\")"));
    }

    // --- fields ---

    #[test]
    fn test_nullable_field() {
        let c = make_composite("address", vec![make_field("zip", "text", true)]);
        let code = gen_code(&c);
        assert!(code.contains("Option<String>"));
    }

    #[test]
    fn test_non_nullable_field() {
        let c = make_composite("address", vec![make_field("city", "text", false)]);
        let code = gen_code(&c);
        assert!(code.contains("pub city: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_keyword_field_prefixed() {
        let c = make_composite("item", vec![make_field("type", "text", false)]);
        let code = gen_code(&c);
        assert!(code.contains("pub item_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    // --- rename ---

    #[test]
    fn test_camel_case_field_renamed() {
        let c = make_composite("address", vec![make_field("StreetName", "text", false)]);
        let code = gen_code(&c);
        assert!(code.contains("pub street_name: String"));
        assert!(code.contains("sqlx(rename = \"StreetName\")"));
    }

    #[test]
    fn test_snake_case_field_no_rename() {
        let c = make_composite("address", vec![make_field("street_name", "text", false)]);
        let code = gen_code(&c);
        assert!(code.contains("pub street_name: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    // --- types ---

    #[test]
    fn test_int4_field() {
        let c = make_composite("data", vec![make_field("count", "int4", false)]);
        let code = gen_code(&c);
        assert!(code.contains("pub count: i32"));
    }

    #[test]
    fn test_uuid_field_import() {
        let c = make_composite("data", vec![make_field("id", "uuid", false)]);
        let (_, imports) = gen_with(&c, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_text_field() {
        let c = make_composite("data", vec![make_field("label", "text", false)]);
        let code = gen_code(&c);
        assert!(code.contains("pub label: String"));
    }

    // --- derives ---

    #[test]
    fn test_default_derives() {
        let c = make_composite("data", vec![make_field("x", "text", false)]);
        let code = gen_code(&c);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("sqlx::Type") || code.contains("sqlx :: Type"));
    }

    #[test]
    fn test_extra_derive() {
        let c = make_composite("data", vec![make_field("x", "text", false)]);
        let derives = vec!["Serialize".to_string()];
        let (code, _) = gen_with(&c, &derives, &HashMap::new());
        assert!(code.contains("Serialize"));
    }

    // --- overrides ---

    #[test]
    fn test_type_override() {
        let c = make_composite("data", vec![make_field("payload", "jsonb", false)]);
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = gen_with(&c, &[], &overrides);
        assert!(code.contains("pub payload: MyJson"));
    }
}
267}