// sqlx_gen/codegen/struct_gen.rs

1use std::collections::{BTreeSet, HashMap};
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::{DatabaseKind, TimeCrate};
8use crate::codegen::{imports_for_derives, is_rust_keyword};
9use crate::introspect::{SchemaInfo, TableInfo};
10use crate::typemap;
11
12pub fn generate_struct(
13    table: &TableInfo,
14    db_kind: DatabaseKind,
15    schema_info: &SchemaInfo,
16    extra_derives: &[String],
17    type_overrides: &HashMap<String, String>,
18    is_view: bool,
19    time_crate: TimeCrate,
20) -> (TokenStream, BTreeSet<String>) {
21    let mut imports = BTreeSet::new();
22    for imp in imports_for_derives(extra_derives) {
23        imports.insert(imp);
24    }
25    let struct_name = format_ident!("{}", table.name.to_upper_camel_case());
26
27    // Build derive list
28    imports.insert("use serde::{Serialize, Deserialize};".to_string());
29    imports.insert("use sqlx_gen::SqlxGen;".to_string());
30    let mut derive_tokens = vec![
31        quote! { Debug },
32        quote! { Clone },
33        quote! { PartialEq },
34        quote! { Eq },
35        quote! { Serialize },
36        quote! { Deserialize },
37        quote! { sqlx::FromRow },
38        quote! { SqlxGen },
39    ];
40    for d in extra_derives {
41        let ident = format_ident!("{}", d);
42        derive_tokens.push(quote! { #ident });
43    }
44
45    // Build fields
46    let fields: Vec<TokenStream> = table
47        .columns
48        .iter()
49        .map(|col| {
50            let rust_type =
51                resolve_column_type(col, db_kind, table, schema_info, type_overrides, time_crate);
52            if let Some(imp) = &rust_type.needs_import {
53                imports.insert(imp.clone());
54            }
55
56            let field_name_snake = sanitize_rust_ident(&col.name.to_snake_case());
57            // If the field name is a Rust keyword, prefix with table name
58            // e.g. column "type" on table "connector" → "connector_type"
59            let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
60                let prefixed = format!("{}_{}", table.name.to_snake_case(), field_name_snake);
61                (prefixed, true)
62            } else {
63                let changed = field_name_snake != col.name;
64                (field_name_snake, changed)
65            };
66
67            let field_ident = format_ident!("{}", effective_name);
68            let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
69                let fallback = format_ident!("String");
70                quote! { #fallback }
71            });
72
73            let rename = if needs_rename {
74                let original = &col.name;
75                quote! { #[sqlx(rename = #original)] }
76            } else {
77                quote! {}
78            };
79
80            // Build #[sqlx_gen(...)] attribute with optional primary_key, sql_type, is_array, column_default
81            let (sql_type, is_sql_array) = detect_custom_sql_type(&col.udt_name, schema_info);
82            let has_pk = col.is_primary_key;
83            let has_sql_type = sql_type.is_some();
84            let has_default = col.column_default.is_some();
85
86            let sqlx_gen_attr = if has_pk || has_sql_type || has_default {
87                let pk_part = if has_pk {
88                    quote! { primary_key, }
89                } else {
90                    quote! {}
91                };
92                let sql_type_part = match &sql_type {
93                    Some(t) => quote! { sql_type = #t, },
94                    None => quote! {},
95                };
96                let array_part = if is_sql_array {
97                    quote! { is_array, }
98                } else {
99                    quote! {}
100                };
101                let default_part = match &col.column_default {
102                    Some(d) => quote! { column_default = #d, },
103                    None => quote! {},
104                };
105                quote! { #[sqlx_gen(#pk_part #sql_type_part #array_part #default_part)] }
106            } else {
107                quote! {}
108            };
109
110            quote! {
111                #sqlx_gen_attr
112                #rename
113                pub #field_ident: #type_tokens,
114            }
115        })
116        .collect();
117
118    let table_name_str = &table.name;
119    let schema_name_str = &table.schema_name;
120    let kind_str = if is_view { "view" } else { "table" };
121
122    let tokens = quote! {
123        #[derive(#(#derive_tokens),*)]
124        #[sqlx_gen(kind = #kind_str, schema = #schema_name_str, table = #table_name_str)]
125        pub struct #struct_name {
126            #(#fields)*
127        }
128    };
129
130    (tokens, imports)
131}
132
/// Sanitize a candidate Rust identifier:
/// - replace any character that is not ascii-alphanumeric or '_' with '_'
/// - prefix with '_' if the result starts with a digit
/// - fall back to "_field" if the input is empty or collapses to a lone `_`
///   (a bare `_` is the reserved wildcard, not a usable field name)
///
/// Lets sqlx-gen survive columns named `user-id`, `created at`, `123`, etc.
/// — they still need a `#[sqlx(rename = "<original>")]` to roundtrip the DB
/// column, which the caller handles via the `changed` flag.
///
/// Keywords are not handled here; the result may still be a Rust keyword.
pub(crate) fn sanitize_rust_ident(name: &str) -> String {
    let mut out: String = name
        .chars()
        .map(|c| {
            if c.is_ascii_alphanumeric() || c == '_' {
                c
            } else {
                '_'
            }
        })
        .collect();
    // Covers "", "_" and any single non-identifier character such as "-" or "é".
    if out.is_empty() || out == "_" {
        return "_field".to_string();
    }
    if out.starts_with(|c: char| c.is_ascii_digit()) {
        out.insert(0, '_');
    }
    out
}
160
161/// Detect if a column uses a custom SQL type (enum or composite) and return the qualified
162/// SQL type name for casting, plus whether it's an array.
163/// Returns `(Some("type_name"), true)` for arrays of custom types,
164/// `(Some("type_name"), false)` for scalar custom types, or `(None, false)` for built-in types.
165fn detect_custom_sql_type(udt_name: &str, schema_info: &SchemaInfo) -> (Option<String>, bool) {
166    let (base_name, is_array) = match udt_name.strip_prefix('_') {
167        Some(inner) => (inner, true),
168        None => (udt_name, false),
169    };
170
171    // Check enums
172    if schema_info.enums.iter().any(|e| e.name == base_name) {
173        return (Some(base_name.to_string()), is_array);
174    }
175
176    // Check composite types
177    if schema_info
178        .composite_types
179        .iter()
180        .any(|c| c.name == base_name)
181    {
182        return (Some(base_name.to_string()), is_array);
183    }
184
185    // Check if this is a non-builtin type that would hit the typemap fallback
186    // (e.g. range types like "timerange", "tsrange", etc.)
187    // Domains resolve to their base type, so they don't need marking.
188    let is_domain = schema_info.domains.iter().any(|d| d.name == base_name);
189    if !is_domain && !typemap::postgres::is_builtin(base_name) {
190        return (Some(base_name.to_string()), is_array);
191    }
192
193    // Native arrays of builtin types (e.g. `_text` → `text[]`) need sql_type annotation
194    // so that codegen falls back to runtime mode — `query_as!` macro doesn't support
195    // `Vec<T>` without `PgHasArrayType`.
196    if is_array {
197        return (Some(base_name.to_string()), true);
198    }
199
200    (None, false)
201}
202
203fn resolve_column_type(
204    col: &crate::introspect::ColumnInfo,
205    db_kind: DatabaseKind,
206    table: &TableInfo,
207    schema_info: &SchemaInfo,
208    type_overrides: &HashMap<String, String>,
209    time_crate: TimeCrate,
210) -> typemap::RustType {
211    // For MySQL ENUM columns, resolve to the generated enum type
212    if db_kind == DatabaseKind::Mysql && col.udt_name.starts_with("enum(") {
213        let enum_type_name = typemap::mysql::resolve_enum_type(&table.name, &col.name);
214        let rt = typemap::RustType::with_import(
215            &enum_type_name,
216            &format!("use super::types::{};", enum_type_name),
217        );
218        return if col.is_nullable {
219            rt.wrap_option()
220        } else {
221            rt
222        };
223    }
224
225    typemap::map_column(col, db_kind, schema_info, type_overrides, time_crate)
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    /// Build a table in the `public` schema with the given columns.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// Build a plain, non-primary-key column without a default.
    fn make_col(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
            udt_schema: None,
            column_default: None,
        }
    }

    /// Build a MySQL `ENUM` column as introspection reports it.
    fn make_mysql_enum_col(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            data_type: "enum".to_string(),
            schema_name: "test_db".to_string(),
            ..make_col(name, udt_name, nullable)
        }
    }

    /// Generate and format a struct for `table`, returning the code and its imports.
    fn gen_with(
        table: &TableInfo,
        schema: &SchemaInfo,
        db: DatabaseKind,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_struct(
            table,
            db,
            schema,
            derives,
            overrides,
            false,
            TimeCrate::Chrono,
        );
        (parse_and_format(&tokens).unwrap(), imports)
    }

    /// Formatted Postgres output with default settings.
    ///
    /// Deliberately not named `gen`: that identifier is a reserved keyword in Rust 2024.
    fn render(table: &TableInfo) -> String {
        gen_with(
            table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &[],
            &HashMap::new(),
        )
        .0
    }

    /// Imports collected for a Postgres table with default settings.
    fn pg_imports(table: &TableInfo) -> BTreeSet<String> {
        gen_with(
            table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &[],
            &HashMap::new(),
        )
        .1
    }

    // --- basic structure ---

    #[test]
    fn test_simple_table() {
        let table = make_table(
            "users",
            vec![
                make_col("id", "int4", false),
                make_col("name", "text", false),
            ],
        );
        let code = render(&table);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub name: String"));
    }

    #[test]
    fn test_struct_name_pascal_case() {
        let table = make_table("user_roles", vec![make_col("id", "int4", false)]);
        let code = render(&table);
        assert!(code.contains("pub struct UserRoles"));
    }

    #[test]
    fn test_struct_name_simple() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = render(&table);
        assert!(code.contains("pub struct Users"));
    }

    // --- nullable ---

    #[test]
    fn test_nullable_column() {
        let table = make_table("users", vec![make_col("email", "text", true)]);
        let code = render(&table);
        assert!(code.contains("pub email: Option<String>"));
    }

    #[test]
    fn test_non_nullable_column() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = render(&table);
        assert!(code.contains("pub name: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_mix_nullable() {
        let table = make_table(
            "users",
            vec![make_col("id", "int4", false), make_col("bio", "text", true)],
        );
        let code = render(&table);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub bio: Option<String>"));
    }

    // --- keyword renaming ---

    #[test]
    fn test_keyword_type_renamed() {
        let table = make_table("connector", vec![make_col("type", "text", false)]);
        let code = render(&table);
        assert!(code.contains("pub connector_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    #[test]
    fn test_keyword_fn_renamed() {
        let table = make_table("item", vec![make_col("fn", "text", false)]);
        let code = render(&table);
        assert!(code.contains("pub item_fn: String"));
        assert!(code.contains("sqlx(rename = \"fn\")"));
    }

    #[test]
    fn test_keyword_match_renamed() {
        let table = make_table("game", vec![make_col("match", "text", false)]);
        let code = render(&table);
        assert!(code.contains("pub game_match: String"));
    }

    #[test]
    fn test_non_keyword_no_rename() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = render(&table);
        assert!(!code.contains("sqlx(rename"));
    }

    // --- snake_case renaming ---

    #[test]
    fn test_camel_case_column_renamed() {
        let table = make_table("users", vec![make_col("CreatedAt", "text", false)]);
        let code = render(&table);
        assert!(code.contains("pub created_at: String"));
        assert!(code.contains("sqlx(rename = \"CreatedAt\")"));
    }

    #[test]
    fn test_mixed_case_column_renamed() {
        let table = make_table("users", vec![make_col("firstName", "text", false)]);
        let code = render(&table);
        assert!(code.contains("pub first_name: String"));
        assert!(code.contains("sqlx(rename = \"firstName\")"));
    }

    #[test]
    fn test_already_snake_case_no_rename() {
        let table = make_table("users", vec![make_col("created_at", "text", false)]);
        let code = render(&table);
        assert!(code.contains("pub created_at: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    // --- derives ---

    #[test]
    fn test_default_derives() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = render(&table);
        for derive in [
            "Debug",
            "Clone",
            "PartialEq",
            "Eq",
            "Serialize",
            "Deserialize",
            "SqlxGen",
        ] {
            assert!(code.contains(derive), "missing derive {}:\n{}", derive, code);
        }
        assert!(code.contains("sqlx::FromRow") || code.contains("sqlx :: FromRow"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let derives = vec!["Serialize".to_string()];
        let (code, _) = gen_with(
            &table,
            &schema,
            DatabaseKind::Postgres,
            &derives,
            &HashMap::new(),
        );
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derives_both_serde() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = gen_with(
            &table,
            &schema,
            DatabaseKind::Postgres,
            &derives,
            &HashMap::new(),
        );
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    // --- imports ---

    #[test]
    fn test_uuid_import() {
        let table = make_table("users", vec![make_col("id", "uuid", false)]);
        let imports = pg_imports(&table);
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_timestamptz_import() {
        let table = make_table("users", vec![make_col("created_at", "timestamptz", false)]);
        let imports = pg_imports(&table);
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    #[test]
    fn test_int4_only_serde_import() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let imports = pg_imports(&table);
        assert_eq!(imports.len(), 2);
        assert!(imports.iter().any(|i| i.contains("serde")));
        assert!(imports.iter().any(|i| i.contains("sqlx_gen::SqlxGen")));
    }

    #[test]
    fn test_multiple_imports_collected() {
        let table = make_table(
            "users",
            vec![
                make_col("id", "uuid", false),
                make_col("created_at", "timestamptz", false),
            ],
        );
        let imports = pg_imports(&table);
        assert!(imports.iter().any(|i| i.contains("uuid")));
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    // --- MySQL enum ---

    #[test]
    fn test_mysql_enum_column() {
        let table = make_table(
            "users",
            vec![make_mysql_enum_col(
                "status",
                "enum('active','inactive')",
                false,
            )],
        );
        let schema = SchemaInfo::default();
        let (code, imports) = gen_with(&table, &schema, DatabaseKind::Mysql, &[], &HashMap::new());
        assert!(code.contains("UsersStatus"));
        assert!(imports.iter().any(|i| i.contains("super::types::")));
    }

    #[test]
    fn test_mysql_enum_nullable() {
        let table = make_table(
            "users",
            vec![make_mysql_enum_col("status", "enum('a','b')", true)],
        );
        let schema = SchemaInfo::default();
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Mysql, &[], &HashMap::new());
        assert!(code.contains("Option<UsersStatus>"));
    }

    // --- type overrides ---

    #[test]
    fn test_type_override() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let schema = SchemaInfo::default();
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &overrides);
        assert!(code.contains("pub data: MyJson"));
    }

    #[test]
    fn test_type_override_absent() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let code = render(&table);
        assert!(code.contains("Value"));
    }

    #[test]
    fn test_type_override_nullable() {
        let table = make_table("users", vec![make_col("data", "jsonb", true)]);
        let schema = SchemaInfo::default();
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &overrides);
        assert!(code.contains("Option<MyJson>"));
    }

    // --- native array types ---

    #[test]
    fn test_native_array_text_gets_sql_type_annotation() {
        let table = make_table("posts", vec![make_col("tags", "_text", false)]);
        let code = render(&table);
        assert!(code.contains("Vec<String>"));
        assert!(code.contains("sql_type = \"text\""));
        assert!(code.contains("is_array"));
    }

    #[test]
    fn test_native_array_int4_gets_sql_type_annotation() {
        let table = make_table("data", vec![make_col("values", "_int4", false)]);
        let code = render(&table);
        assert!(code.contains("Vec<i32>"));
        assert!(code.contains("sql_type = \"int4\""));
        assert!(code.contains("is_array"));
    }

    #[test]
    fn test_native_array_nullable_gets_sql_type_annotation() {
        let table = make_table("posts", vec![make_col("tags", "_text", true)]);
        let code = render(&table);
        assert!(code.contains("Option<Vec<String>>"));
        assert!(code.contains("sql_type = \"text\""));
        assert!(code.contains("is_array"));
    }

    #[test]
    fn test_scalar_builtin_no_sql_type_annotation() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = render(&table);
        assert!(code.contains("pub name: String"));
        assert!(!code.contains("sql_type"));
    }

    // --- field attributes ---

    #[test]
    fn test_primary_key_annotation() {
        let pk = ColumnInfo {
            is_primary_key: true,
            ..make_col("id", "int4", false)
        };
        let with_pk = render(&make_table("users", vec![pk]));
        assert!(with_pk.contains("primary_key"));

        let without_pk = render(&make_table("users", vec![make_col("id", "int4", false)]));
        assert!(!without_pk.contains("primary_key"));
    }

    // ========== sanitize_rust_ident ==========

    #[test]
    fn test_sanitize_replaces_dash() {
        assert_eq!(sanitize_rust_ident("user-id"), "user_id");
    }

    #[test]
    fn test_sanitize_replaces_space() {
        assert_eq!(sanitize_rust_ident("created at"), "created_at");
    }

    #[test]
    fn test_sanitize_replaces_dot() {
        assert_eq!(sanitize_rust_ident("a.b"), "a_b");
    }

    #[test]
    fn test_sanitize_prefixes_leading_digit() {
        assert_eq!(sanitize_rust_ident("123abc"), "_123abc");
    }

    #[test]
    fn test_sanitize_empty_becomes_placeholder() {
        assert_eq!(sanitize_rust_ident(""), "_field");
    }

    #[test]
    fn test_sanitize_leaves_valid_ident_unchanged() {
        assert_eq!(sanitize_rust_ident("user_id"), "user_id");
        assert_eq!(sanitize_rust_ident("_private"), "_private");
    }

    #[test]
    fn test_column_with_dash_generates_valid_rust() {
        let table = make_table("users", vec![make_col("user-id", "int4", false)]);
        let code = render(&table);
        // Must produce a Rust-legal identifier; renamed back to the original via #[sqlx(rename)]
        assert!(
            code.contains("pub user_id: i32"),
            "expected sanitized identifier, got:\n{}",
            code
        );
        assert!(code.contains("sqlx(rename = \"user-id\")"));
    }
}