//! Generates Rust struct definitions for database tables and views
//! (`sqlx_gen/codegen/struct_gen.rs`).

1use std::collections::{BTreeSet, HashMap};
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::{DatabaseKind, TimeCrate};
8use crate::codegen::{imports_for_derives, is_rust_keyword};
9use crate::introspect::{SchemaInfo, TableInfo};
10use crate::typemap;
11
12pub fn generate_struct(
13    table: &TableInfo,
14    db_kind: DatabaseKind,
15    schema_info: &SchemaInfo,
16    extra_derives: &[String],
17    type_overrides: &HashMap<String, String>,
18    is_view: bool,
19    time_crate: TimeCrate,
20) -> (TokenStream, BTreeSet<String>) {
21    let mut imports = BTreeSet::new();
22    for imp in imports_for_derives(extra_derives) {
23        imports.insert(imp);
24    }
25    let struct_name = format_ident!("{}", table.name.to_upper_camel_case());
26
27    // Build derive list
28    imports.insert("use serde::{Serialize, Deserialize};".to_string());
29    imports.insert("use sqlx_gen::SqlxGen;".to_string());
30    let mut derive_tokens = vec![
31        quote! { Debug },
32        quote! { Clone },
33        quote! { PartialEq },
34        quote! { Eq },
35        quote! { Serialize },
36        quote! { Deserialize },
37        quote! { sqlx::FromRow },
38        quote! { SqlxGen },
39    ];
40    for d in extra_derives {
41        let ident = format_ident!("{}", d);
42        derive_tokens.push(quote! { #ident });
43    }
44
45    // Build fields
46    let fields: Vec<TokenStream> = table
47        .columns
48        .iter()
49        .map(|col| {
50            let rust_type = resolve_column_type(col, db_kind, table, schema_info, type_overrides, time_crate);
51            if let Some(imp) = &rust_type.needs_import {
52                imports.insert(imp.clone());
53            }
54
55            let field_name_snake = col.name.to_snake_case();
56            // If the field name is a Rust keyword, prefix with table name
57            // e.g. column "type" on table "connector" → "connector_type"
58            let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
59                let prefixed = format!(
60                    "{}_{}",
61                    table.name.to_snake_case(),
62                    field_name_snake
63                );
64                (prefixed, true)
65            } else {
66                let changed = field_name_snake != col.name;
67                (field_name_snake, changed)
68            };
69
70            let field_ident = format_ident!("{}", effective_name);
71            let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
72                let fallback = format_ident!("String");
73                quote! { #fallback }
74            });
75
76            let rename = if needs_rename {
77                let original = &col.name;
78                quote! { #[sqlx(rename = #original)] }
79            } else {
80                quote! {}
81            };
82
83            // Build #[sqlx_gen(...)] attribute with optional primary_key, sql_type, is_array, column_default
84            let (sql_type, is_sql_array) = detect_custom_sql_type(&col.udt_name, schema_info);
85            let has_pk = col.is_primary_key;
86            let has_sql_type = sql_type.is_some();
87            let has_default = col.column_default.is_some();
88
89            let sqlx_gen_attr = if has_pk || has_sql_type || has_default {
90                let pk_part = if has_pk { quote! { primary_key, } } else { quote! {} };
91                let sql_type_part = match &sql_type {
92                    Some(t) => quote! { sql_type = #t, },
93                    None => quote! {},
94                };
95                let array_part = if is_sql_array { quote! { is_array, } } else { quote! {} };
96                let default_part = match &col.column_default {
97                    Some(d) => quote! { column_default = #d, },
98                    None => quote! {},
99                };
100                quote! { #[sqlx_gen(#pk_part #sql_type_part #array_part #default_part)] }
101            } else {
102                quote! {}
103            };
104
105            quote! {
106                #sqlx_gen_attr
107                #rename
108                pub #field_ident: #type_tokens,
109            }
110        })
111        .collect();
112
113    let table_name_str = &table.name;
114    let schema_name_str = &table.schema_name;
115    let kind_str = if is_view { "view" } else { "table" };
116
117    let tokens = quote! {
118        #[derive(#(#derive_tokens),*)]
119        #[sqlx_gen(kind = #kind_str, schema = #schema_name_str, table = #table_name_str)]
120        pub struct #struct_name {
121            #(#fields)*
122        }
123    };
124
125    (tokens, imports)
126}
127
128/// Detect if a column uses a custom SQL type (enum or composite) and return the qualified
129/// SQL type name for casting, plus whether it's an array.
130/// Returns `(Some("type_name"), true)` for arrays of custom types,
131/// `(Some("type_name"), false)` for scalar custom types, or `(None, false)` for built-in types.
132fn detect_custom_sql_type(udt_name: &str, schema_info: &SchemaInfo) -> (Option<String>, bool) {
133    let (base_name, is_array) = match udt_name.strip_prefix('_') {
134        Some(inner) => (inner, true),
135        None => (udt_name, false),
136    };
137
138    // Check enums
139    if schema_info.enums.iter().any(|e| e.name == base_name) {
140        return (Some(base_name.to_string()), is_array);
141    }
142
143    // Check composite types
144    if schema_info.composite_types.iter().any(|c| c.name == base_name) {
145        return (Some(base_name.to_string()), is_array);
146    }
147
148    // Check if this is a non-builtin type that would hit the typemap fallback
149    // (e.g. range types like "timerange", "tsrange", etc.)
150    // Domains resolve to their base type, so they don't need marking.
151    let is_domain = schema_info.domains.iter().any(|d| d.name == base_name);
152    if !is_domain && !typemap::postgres::is_builtin(base_name) {
153        return (Some(base_name.to_string()), is_array);
154    }
155
156    // Native arrays of builtin types (e.g. `_text` → `text[]`) need sql_type annotation
157    // so that codegen falls back to runtime mode — `query_as!` macro doesn't support
158    // `Vec<T>` without `PgHasArrayType`.
159    if is_array {
160        return (Some(base_name.to_string()), true);
161    }
162
163    (None, false)
164}
165
166fn resolve_column_type(
167    col: &crate::introspect::ColumnInfo,
168    db_kind: DatabaseKind,
169    table: &TableInfo,
170    schema_info: &SchemaInfo,
171    type_overrides: &HashMap<String, String>,
172    time_crate: TimeCrate,
173) -> typemap::RustType {
174    // For MySQL ENUM columns, resolve to the generated enum type
175    if db_kind == DatabaseKind::Mysql && col.udt_name.starts_with("enum(") {
176        let enum_type_name = typemap::mysql::resolve_enum_type(&table.name, &col.name);
177        let rt = typemap::RustType::with_import(
178            &enum_type_name,
179            &format!("use super::types::{};", enum_type_name),
180        );
181        return if col.is_nullable {
182            rt.wrap_option()
183        } else {
184            rt
185        };
186    }
187
188    typemap::map_column(col, db_kind, schema_info, type_overrides, time_crate)
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    /// Builds a `public`-schema table with the given columns.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// Builds a non-PK, default-less column whose data_type and udt_name are both `udt_name`.
    fn make_col(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
            column_default: None,
        }
    }

    /// Generates and formats a Postgres table struct with no extra derives or overrides.
    fn gen(table: &TableInfo) -> String {
        let schema = SchemaInfo::default();
        let (tokens, _) = generate_struct(table, DatabaseKind::Postgres, &schema, &[], &HashMap::new(), false, TimeCrate::Chrono);
        parse_and_format(&tokens)
    }

    /// Generates and formats a table struct with explicit backend, derives and overrides.
    fn gen_with(
        table: &TableInfo,
        schema: &SchemaInfo,
        db: DatabaseKind,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_struct(table, db, schema, derives, overrides, false, TimeCrate::Chrono);
        (parse_and_format(&tokens), imports)
    }

    // --- basic structure ---

    #[test]
    fn test_simple_table() {
        let table = make_table("users", vec![
            make_col("id", "int4", false),
            make_col("name", "text", false),
        ]);
        let code = gen(&table);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub name: String"));
    }

    #[test]
    fn test_struct_name_pascal_case() {
        let table = make_table("user_roles", vec![make_col("id", "int4", false)]);
        let code = gen(&table);
        assert!(code.contains("pub struct UserRoles"));
    }

    #[test]
    fn test_struct_name_simple() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = gen(&table);
        assert!(code.contains("pub struct Users"));
    }

    // --- nullable ---

    #[test]
    fn test_nullable_column() {
        let table = make_table("users", vec![make_col("email", "text", true)]);
        let code = gen(&table);
        assert!(code.contains("pub email: Option<String>"));
    }

    #[test]
    fn test_non_nullable_column() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub name: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_mix_nullable() {
        let table = make_table("users", vec![
            make_col("id", "int4", false),
            make_col("bio", "text", true),
        ]);
        let code = gen(&table);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub bio: Option<String>"));
    }

    // --- keyword renaming ---

    #[test]
    fn test_keyword_type_renamed() {
        let table = make_table("connector", vec![make_col("type", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub connector_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    #[test]
    fn test_keyword_fn_renamed() {
        let table = make_table("item", vec![make_col("fn", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub item_fn: String"));
        assert!(code.contains("sqlx(rename = \"fn\")"));
    }

    #[test]
    fn test_keyword_match_renamed() {
        let table = make_table("game", vec![make_col("match", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub game_match: String"));
    }

    #[test]
    fn test_non_keyword_no_rename() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = gen(&table);
        assert!(!code.contains("sqlx(rename"));
    }

    // --- snake_case renaming ---

    #[test]
    fn test_camel_case_column_renamed() {
        let table = make_table("users", vec![make_col("CreatedAt", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub created_at: String"));
        assert!(code.contains("sqlx(rename = \"CreatedAt\")"));
    }

    #[test]
    fn test_mixed_case_column_renamed() {
        let table = make_table("users", vec![make_col("firstName", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub first_name: String"));
        assert!(code.contains("sqlx(rename = \"firstName\")"));
    }

    #[test]
    fn test_already_snake_case_no_rename() {
        let table = make_table("users", vec![make_col("created_at", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub created_at: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    // --- derives ---

    #[test]
    fn test_default_derives() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = gen(&table);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("sqlx::FromRow") || code.contains("sqlx :: FromRow"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let derives = vec!["Serialize".to_string()];
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &derives, &HashMap::new());
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derives_both_serde() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &derives, &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    // --- struct-level #[sqlx_gen] attribute ---

    #[test]
    fn test_table_kind_attribute() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = gen(&table);
        assert!(code.contains("kind = \"table\""));
        assert!(code.contains("schema = \"public\""));
        assert!(code.contains("table = \"users\""));
    }

    #[test]
    fn test_view_kind_attribute() {
        let table = make_table("active_users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let (tokens, _) = generate_struct(&table, DatabaseKind::Postgres, &schema, &[], &HashMap::new(), true, TimeCrate::Chrono);
        let code = parse_and_format(&tokens);
        assert!(code.contains("pub struct ActiveUsers"));
        assert!(code.contains("kind = \"view\""));
        assert!(!code.contains("kind = \"table\""));
    }

    // --- field-level #[sqlx_gen] attribute ---

    #[test]
    fn test_primary_key_attribute() {
        let mut id = make_col("id", "int4", false);
        id.is_primary_key = true;
        let table = make_table("users", vec![id]);
        let code = gen(&table);
        assert!(code.contains("primary_key"));
        assert!(!code.contains("column_default"));
    }

    #[test]
    fn test_column_default_attribute() {
        let mut created = make_col("created_at", "timestamptz", false);
        created.column_default = Some("now()".to_string());
        let table = make_table("users", vec![created]);
        let code = gen(&table);
        assert!(code.contains("column_default = \"now()\""));
        assert!(!code.contains("primary_key"));
    }

    #[test]
    fn test_plain_column_has_no_field_attribute() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = gen(&table);
        // Only the struct-level `#[sqlx_gen(kind = ...)]` attribute should be present.
        assert_eq!(code.matches("sqlx_gen(").count(), 1);
    }

    // --- imports ---

    #[test]
    fn test_uuid_import() {
        let table = make_table("users", vec![make_col("id", "uuid", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_timestamptz_import() {
        let table = make_table("users", vec![make_col("created_at", "timestamptz", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    #[test]
    fn test_int4_only_serde_import() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert_eq!(imports.len(), 2);
        assert!(imports.iter().any(|i| i.contains("serde")));
        assert!(imports.iter().any(|i| i.contains("sqlx_gen::SqlxGen")));
    }

    #[test]
    fn test_multiple_imports_collected() {
        let table = make_table("users", vec![
            make_col("id", "uuid", false),
            make_col("created_at", "timestamptz", false),
        ]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid")));
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    // --- MySQL enum ---

    #[test]
    fn test_mysql_enum_column() {
        let table = make_table("users", vec![ColumnInfo {
            name: "status".to_string(),
            data_type: "enum".to_string(),
            udt_name: "enum('active','inactive')".to_string(),
            is_nullable: false,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "test_db".to_string(),
            column_default: None,
        }]);
        let schema = SchemaInfo::default();
        let (code, imports) = gen_with(&table, &schema, DatabaseKind::Mysql, &[], &HashMap::new());
        assert!(code.contains("UsersStatus"));
        assert!(imports.iter().any(|i| i.contains("super::types::")));
    }

    #[test]
    fn test_mysql_enum_nullable() {
        let table = make_table("users", vec![ColumnInfo {
            name: "status".to_string(),
            data_type: "enum".to_string(),
            udt_name: "enum('a','b')".to_string(),
            is_nullable: true,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "test_db".to_string(),
            column_default: None,
        }]);
        let schema = SchemaInfo::default();
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Mysql, &[], &HashMap::new());
        assert!(code.contains("Option<UsersStatus>"));
    }

    // --- type overrides ---

    #[test]
    fn test_type_override() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let schema = SchemaInfo::default();
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &overrides);
        assert!(code.contains("pub data: MyJson"));
    }

    #[test]
    fn test_type_override_absent() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let schema = SchemaInfo::default();
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(code.contains("Value"));
    }

    #[test]
    fn test_type_override_nullable() {
        let table = make_table("users", vec![make_col("data", "jsonb", true)]);
        let schema = SchemaInfo::default();
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &overrides);
        assert!(code.contains("Option<MyJson>"));
    }

    // --- native array types ---

    #[test]
    fn test_native_array_text_gets_sql_type_annotation() {
        let table = make_table("posts", vec![make_col("tags", "_text", false)]);
        let code = gen(&table);
        assert!(code.contains("Vec<String>"));
        assert!(code.contains("sql_type = \"text\""));
        assert!(code.contains("is_array"));
    }

    #[test]
    fn test_native_array_int4_gets_sql_type_annotation() {
        let table = make_table("data", vec![make_col("values", "_int4", false)]);
        let code = gen(&table);
        assert!(code.contains("Vec<i32>"));
        assert!(code.contains("sql_type = \"int4\""));
        assert!(code.contains("is_array"));
    }

    #[test]
    fn test_native_array_nullable_gets_sql_type_annotation() {
        let table = make_table("posts", vec![make_col("tags", "_text", true)]);
        let code = gen(&table);
        assert!(code.contains("Option<Vec<String>>"));
        assert!(code.contains("sql_type = \"text\""));
        assert!(code.contains("is_array"));
    }

    #[test]
    fn test_scalar_builtin_no_sql_type_annotation() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub name: String"));
        assert!(!code.contains("sql_type"));
    }
}
520}