Skip to main content

sqlx_gen/codegen/
struct_gen.rs

1use std::collections::{BTreeSet, HashMap};
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::{imports_for_derives, is_rust_keyword};
9use crate::introspect::{SchemaInfo, TableInfo};
10use crate::typemap;
11
12pub fn generate_struct(
13    table: &TableInfo,
14    db_kind: DatabaseKind,
15    schema_info: &SchemaInfo,
16    extra_derives: &[String],
17    type_overrides: &HashMap<String, String>,
18    is_view: bool,
19) -> (TokenStream, BTreeSet<String>) {
20    let mut imports = BTreeSet::new();
21    for imp in imports_for_derives(extra_derives) {
22        imports.insert(imp);
23    }
24    let struct_name = format_ident!("{}", table.name.to_upper_camel_case());
25
26    // Build derive list
27    imports.insert("use serde::{Serialize, Deserialize};".to_string());
28    imports.insert("use sqlx_gen::SqlxGen;".to_string());
29    let mut derive_tokens = vec![
30        quote! { Debug },
31        quote! { Clone },
32        quote! { PartialEq },
33        quote! { Eq },
34        quote! { Serialize },
35        quote! { Deserialize },
36        quote! { sqlx::FromRow },
37        quote! { SqlxGen },
38    ];
39    for d in extra_derives {
40        let ident = format_ident!("{}", d);
41        derive_tokens.push(quote! { #ident });
42    }
43
44    // Build fields
45    let fields: Vec<TokenStream> = table
46        .columns
47        .iter()
48        .map(|col| {
49            let rust_type = resolve_column_type(col, db_kind, table, schema_info, type_overrides);
50            if let Some(imp) = &rust_type.needs_import {
51                imports.insert(imp.clone());
52            }
53
54            let field_name_snake = col.name.to_snake_case();
55            // If the field name is a Rust keyword, prefix with table name
56            // e.g. column "type" on table "connector" → "connector_type"
57            let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
58                let prefixed = format!(
59                    "{}_{}",
60                    table.name.to_snake_case(),
61                    field_name_snake
62                );
63                (prefixed, true)
64            } else {
65                let changed = field_name_snake != col.name;
66                (field_name_snake, changed)
67            };
68
69            let field_ident = format_ident!("{}", effective_name);
70            let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
71                let fallback = format_ident!("String");
72                quote! { #fallback }
73            });
74
75            let rename = if needs_rename {
76                let original = &col.name;
77                quote! { #[sqlx(rename = #original)] }
78            } else {
79                quote! {}
80            };
81
82            let pk_attr = if col.is_primary_key {
83                quote! { #[sqlx_gen(primary_key)] }
84            } else {
85                quote! {}
86            };
87
88            quote! {
89                #pk_attr
90                #rename
91                pub #field_ident: #type_tokens,
92            }
93        })
94        .collect();
95
96    let table_name_str = &table.name;
97    let schema_name_str = &table.schema_name;
98    let kind_str = if is_view { "view" } else { "table" };
99
100    let tokens = quote! {
101        #[derive(#(#derive_tokens),*)]
102        #[sqlx_gen(kind = #kind_str, schema = #schema_name_str, table = #table_name_str)]
103        pub struct #struct_name {
104            #(#fields)*
105        }
106    };
107
108    (tokens, imports)
109}
110
111fn resolve_column_type(
112    col: &crate::introspect::ColumnInfo,
113    db_kind: DatabaseKind,
114    table: &TableInfo,
115    schema_info: &SchemaInfo,
116    type_overrides: &HashMap<String, String>,
117) -> typemap::RustType {
118    // For MySQL ENUM columns, resolve to the generated enum type
119    if db_kind == DatabaseKind::Mysql && col.udt_name.starts_with("enum(") {
120        let enum_type_name = typemap::mysql::resolve_enum_type(&table.name, &col.name);
121        let rt = typemap::RustType::with_import(
122            &enum_type_name,
123            &format!("use super::types::{};", enum_type_name),
124        );
125        return if col.is_nullable {
126            rt.wrap_option()
127        } else {
128            rt
129        };
130    }
131
132    typemap::map_column(col, db_kind, schema_info, type_overrides)
133}
134
// Unit tests for `generate_struct`: each test renders the generated tokens with
// `parse_and_format` and checks substrings of the resulting code and/or the collected imports.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    // Builds a `TableInfo` in the "public" schema with the given columns.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    // Builds a non-primary-key column whose `data_type` and `udt_name` are both `udt_name`.
    fn make_col(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
        }
    }

    // Generates a Postgres table struct with default schema info, no extra derives and no
    // type overrides; returns only the formatted code.
    fn gen(table: &TableInfo) -> String {
        let schema = SchemaInfo::default();
        let (tokens, _) = generate_struct(table, DatabaseKind::Postgres, &schema, &[], &HashMap::new(), false);
        parse_and_format(&tokens)
    }

    // Like `gen`, but with an explicit schema, database kind, extra derives and overrides;
    // returns the formatted code together with the collected imports.
    fn gen_with(
        table: &TableInfo,
        schema: &SchemaInfo,
        db: DatabaseKind,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_struct(table, db, schema, derives, overrides, false);
        (parse_and_format(&tokens), imports)
    }

    // --- basic structure ---

    #[test]
    fn test_simple_table() {
        let table = make_table("users", vec![
            make_col("id", "int4", false),
            make_col("name", "text", false),
        ]);
        let code = gen(&table);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub name: String"));
    }

    #[test]
    fn test_struct_name_pascal_case() {
        let table = make_table("user_roles", vec![make_col("id", "int4", false)]);
        let code = gen(&table);
        assert!(code.contains("pub struct UserRoles"));
    }

    #[test]
    fn test_struct_name_simple() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = gen(&table);
        assert!(code.contains("pub struct Users"));
    }

    // --- nullable ---

    #[test]
    fn test_nullable_column() {
        let table = make_table("users", vec![make_col("email", "text", true)]);
        let code = gen(&table);
        assert!(code.contains("pub email: Option<String>"));
    }

    #[test]
    fn test_non_nullable_column() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub name: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_mix_nullable() {
        let table = make_table("users", vec![
            make_col("id", "int4", false),
            make_col("bio", "text", true),
        ]);
        let code = gen(&table);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub bio: Option<String>"));
    }

    // --- keyword renaming ---
    // Keyword column names get the table name as a prefix plus a `sqlx(rename)` back to
    // the original column name.

    #[test]
    fn test_keyword_type_renamed() {
        let table = make_table("connector", vec![make_col("type", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub connector_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    #[test]
    fn test_keyword_fn_renamed() {
        let table = make_table("item", vec![make_col("fn", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub item_fn: String"));
        assert!(code.contains("sqlx(rename = \"fn\")"));
    }

    #[test]
    fn test_keyword_match_renamed() {
        let table = make_table("game", vec![make_col("match", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub game_match: String"));
    }

    #[test]
    fn test_non_keyword_no_rename() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = gen(&table);
        assert!(!code.contains("sqlx(rename"));
    }

    // --- snake_case renaming ---
    // Non-snake_case column names become snake_case fields with a `sqlx(rename)` attribute.

    #[test]
    fn test_camel_case_column_renamed() {
        let table = make_table("users", vec![make_col("CreatedAt", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub created_at: String"));
        assert!(code.contains("sqlx(rename = \"CreatedAt\")"));
    }

    #[test]
    fn test_mixed_case_column_renamed() {
        let table = make_table("users", vec![make_col("firstName", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub first_name: String"));
        assert!(code.contains("sqlx(rename = \"firstName\")"));
    }

    #[test]
    fn test_already_snake_case_no_rename() {
        let table = make_table("users", vec![make_col("created_at", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub created_at: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    // --- derives ---

    #[test]
    fn test_default_derives() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = gen(&table);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        // Accept either token spacing, depending on how the path is formatted.
        assert!(code.contains("sqlx::FromRow") || code.contains("sqlx :: FromRow"));
    }

    // NOTE(review): `Serialize` is already in the default derive list, so this assertion
    // only checks that the name appears somewhere; it would not detect a duplicated derive.
    #[test]
    fn test_extra_derive_serialize() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let derives = vec!["Serialize".to_string()];
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &derives, &HashMap::new());
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derives_both_serde() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &derives, &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    // --- imports ---

    #[test]
    fn test_uuid_import() {
        let table = make_table("users", vec![make_col("id", "uuid", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_timestamptz_import() {
        let table = make_table("users", vec![make_col("created_at", "timestamptz", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    // With no extra derives and only an `int4` column, exactly the two always-added imports
    // (serde and SqlxGen) are expected.
    #[test]
    fn test_int4_only_serde_import() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert_eq!(imports.len(), 2);
        assert!(imports.iter().any(|i| i.contains("serde")));
        assert!(imports.iter().any(|i| i.contains("sqlx_gen::SqlxGen")));
    }

    #[test]
    fn test_multiple_imports_collected() {
        let table = make_table("users", vec![
            make_col("id", "uuid", false),
            make_col("created_at", "timestamptz", false),
        ]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid")));
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    // --- MySQL enum ---
    // A MySQL column whose udt_name starts with "enum(" resolves to a generated enum type
    // imported from `super::types`.

    #[test]
    fn test_mysql_enum_column() {
        let table = make_table("users", vec![ColumnInfo {
            name: "status".to_string(),
            data_type: "enum".to_string(),
            udt_name: "enum('active','inactive')".to_string(),
            is_nullable: false,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "test_db".to_string(),
        }]);
        let schema = SchemaInfo::default();
        let (code, imports) = gen_with(&table, &schema, DatabaseKind::Mysql, &[], &HashMap::new());
        assert!(code.contains("UsersStatus"));
        assert!(imports.iter().any(|i| i.contains("super::types::")));
    }

    #[test]
    fn test_mysql_enum_nullable() {
        let table = make_table("users", vec![ColumnInfo {
            name: "status".to_string(),
            data_type: "enum".to_string(),
            udt_name: "enum('a','b')".to_string(),
            is_nullable: true,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "test_db".to_string(),
        }]);
        let schema = SchemaInfo::default();
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Mysql, &[], &HashMap::new());
        assert!(code.contains("Option<UsersStatus>"));
    }

    // --- type overrides ---
    // These tests key `type_overrides` by the column's udt name ("jsonb").

    #[test]
    fn test_type_override() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let schema = SchemaInfo::default();
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &overrides);
        assert!(code.contains("pub data: MyJson"));
    }

    // Without an override, jsonb maps to a type whose name contains "Value"
    // (presumably serde_json::Value; confirm against typemap).
    #[test]
    fn test_type_override_absent() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let schema = SchemaInfo::default();
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(code.contains("Value"));
    }

    #[test]
    fn test_type_override_nullable() {
        let table = make_table("users", vec![make_col("data", "jsonb", true)]);
        let schema = SchemaInfo::default();
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &overrides);
        assert!(code.contains("Option<MyJson>"));
    }
}