Skip to main content

sqlx_gen/codegen/
struct_gen.rs

1use std::collections::{BTreeSet, HashMap};
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::{DatabaseKind, TimeCrate};
8use crate::codegen::naming::singularize;
9use crate::codegen::{imports_for_derives, is_rust_keyword};
10use crate::introspect::{SchemaInfo, TableInfo};
11use crate::typemap;
12
13pub fn generate_struct(
14    table: &TableInfo,
15    db_kind: DatabaseKind,
16    schema_info: &SchemaInfo,
17    extra_derives: &[String],
18    type_overrides: &HashMap<String, String>,
19    is_view: bool,
20    time_crate: TimeCrate,
21) -> (TokenStream, BTreeSet<String>) {
22    let mut imports = BTreeSet::new();
23    for imp in imports_for_derives(extra_derives) {
24        imports.insert(imp);
25    }
26    // Tables are conventionally plural ("users"), structs singular ("User"),
27    // matching the Rust ORM ecosystem (Diesel, SeaORM) and ActiveRecord.
28    let struct_name = format_ident!("{}", singularize(&table.name).to_upper_camel_case());
29
30    // Build derive list
31    imports.insert("use serde::{Serialize, Deserialize};".to_string());
32    imports.insert("use sqlx_gen::SqlxGen;".to_string());
33    let mut derive_tokens = vec![
34        quote! { Debug },
35        quote! { Clone },
36        quote! { PartialEq },
37        quote! { Eq },
38        quote! { Serialize },
39        quote! { Deserialize },
40        quote! { sqlx::FromRow },
41        quote! { SqlxGen },
42    ];
43    for d in extra_derives {
44        let ident = format_ident!("{}", d);
45        derive_tokens.push(quote! { #ident });
46    }
47
48    // Build fields
49    let fields: Vec<TokenStream> = table
50        .columns
51        .iter()
52        .map(|col| {
53            let rust_type =
54                resolve_column_type(col, db_kind, table, schema_info, type_overrides, time_crate);
55            if let Some(imp) = &rust_type.needs_import {
56                imports.insert(imp.clone());
57            }
58
59            let field_name_snake = sanitize_rust_ident(&col.name.to_snake_case());
60            // If the field name is a Rust keyword, prefix with the singular
61            // form of the table name so column "type" on table "products"
62            // becomes "product_type", not "products_type".
63            let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
64                let prefix = singularize(&table.name).to_snake_case();
65                let prefixed = format!("{}_{}", prefix, field_name_snake);
66                (prefixed, true)
67            } else {
68                let changed = field_name_snake != col.name;
69                (field_name_snake, changed)
70            };
71
72            let field_ident = format_ident!("{}", effective_name);
73            let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
74                let fallback = format_ident!("String");
75                quote! { #fallback }
76            });
77
78            let rename = if needs_rename {
79                let original = &col.name;
80                quote! { #[sqlx(rename = #original)] }
81            } else {
82                quote! {}
83            };
84
85            // Build #[sqlx_gen(...)] attribute with optional primary_key, sql_type, is_array, column_default
86            let (sql_type, is_sql_array) = detect_custom_sql_type(&col.udt_name, schema_info);
87            let has_pk = col.is_primary_key;
88            let has_sql_type = sql_type.is_some();
89            let has_default = col.column_default.is_some();
90
91            let sqlx_gen_attr = if has_pk || has_sql_type || has_default {
92                let pk_part = if has_pk {
93                    quote! { primary_key, }
94                } else {
95                    quote! {}
96                };
97                let sql_type_part = match &sql_type {
98                    Some(t) => quote! { sql_type = #t, },
99                    None => quote! {},
100                };
101                let array_part = if is_sql_array {
102                    quote! { is_array, }
103                } else {
104                    quote! {}
105                };
106                let default_part = match &col.column_default {
107                    Some(d) => quote! { column_default = #d, },
108                    None => quote! {},
109                };
110                quote! { #[sqlx_gen(#pk_part #sql_type_part #array_part #default_part)] }
111            } else {
112                quote! {}
113            };
114
115            quote! {
116                #sqlx_gen_attr
117                #rename
118                pub #field_ident: #type_tokens,
119            }
120        })
121        .collect();
122
123    let table_name_str = &table.name;
124    let schema_name_str = &table.schema_name;
125    let kind_str = if is_view { "view" } else { "table" };
126
127    let tokens = quote! {
128        #[derive(#(#derive_tokens),*)]
129        #[sqlx_gen(kind = #kind_str, schema = #schema_name_str, table = #table_name_str)]
130        pub struct #struct_name {
131            #(#fields)*
132        }
133    };
134
135    (tokens, imports)
136}
137
/// Sanitize a candidate Rust identifier:
/// - replace any character that is not ASCII alphanumeric or '_' with '_'
/// - prefix with '_' if the result starts with a digit
/// - fall back to "_field" if the input is empty, or if sanitizing leaves a
///   lone "_" (the wildcard, which cannot name a struct field)
///
/// Lets sqlx-gen survive columns named `user-id`, `created at`, `123`, `é`,
/// etc. — they still need a `#[sqlx(rename = "<original>")]` to roundtrip the
/// DB column, which the caller emits whenever the sanitized name differs from
/// the original column name.
pub(crate) fn sanitize_rust_ident(name: &str) -> String {
    let mut out: String = name
        .chars()
        .map(|c| {
            if c.is_ascii_alphanumeric() || c == '_' {
                c
            } else {
                '_'
            }
        })
        .collect();
    // `_` alone is a reserved pattern, not an identifier: `pub _: T` is invalid.
    if out.is_empty() || out == "_" {
        return "_field".to_string();
    }
    if out.starts_with(|c: char| c.is_ascii_digit()) {
        out.insert(0, '_');
    }
    out
}
165
166/// Detect if a column uses a custom SQL type (enum or composite) and return the qualified
167/// SQL type name for casting, plus whether it's an array.
168/// Returns `(Some("type_name"), true)` for arrays of custom types,
169/// `(Some("type_name"), false)` for scalar custom types, or `(None, false)` for built-in types.
170fn detect_custom_sql_type(udt_name: &str, schema_info: &SchemaInfo) -> (Option<String>, bool) {
171    let (base_name, is_array) = match udt_name.strip_prefix('_') {
172        Some(inner) => (inner, true),
173        None => (udt_name, false),
174    };
175
176    // Check enums
177    if schema_info.enums.iter().any(|e| e.name == base_name) {
178        return (Some(base_name.to_string()), is_array);
179    }
180
181    // Check composite types
182    if schema_info
183        .composite_types
184        .iter()
185        .any(|c| c.name == base_name)
186    {
187        return (Some(base_name.to_string()), is_array);
188    }
189
190    // Check if this is a non-builtin type that would hit the typemap fallback
191    // (e.g. range types like "timerange", "tsrange", etc.)
192    // Domains resolve to their base type, so they don't need marking.
193    let is_domain = schema_info.domains.iter().any(|d| d.name == base_name);
194    if !is_domain && !typemap::postgres::is_builtin(base_name) {
195        return (Some(base_name.to_string()), is_array);
196    }
197
198    // Native arrays of builtin types (e.g. `_text` → `text[]`) need sql_type annotation
199    // so that codegen falls back to runtime mode — `query_as!` macro doesn't support
200    // `Vec<T>` without `PgHasArrayType`.
201    if is_array {
202        return (Some(base_name.to_string()), true);
203    }
204
205    (None, false)
206}
207
208fn resolve_column_type(
209    col: &crate::introspect::ColumnInfo,
210    db_kind: DatabaseKind,
211    table: &TableInfo,
212    schema_info: &SchemaInfo,
213    type_overrides: &HashMap<String, String>,
214    time_crate: TimeCrate,
215) -> typemap::RustType {
216    // For MySQL ENUM columns, resolve to the generated enum type
217    if db_kind == DatabaseKind::Mysql && col.udt_name.starts_with("enum(") {
218        let enum_type_name = typemap::mysql::resolve_enum_type(&table.name, &col.name);
219        let rt = typemap::RustType::with_import(
220            &enum_type_name,
221            &format!("use super::types::{};", enum_type_name),
222        );
223        return if col.is_nullable {
224            rt.wrap_option()
225        } else {
226            rt
227        };
228    }
229
230    typemap::map_column(col, db_kind, schema_info, type_overrides, time_crate)
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    /// Table in schema "public" with the given columns.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// Non-PK column in schema "public" with no default; `data_type` mirrors `udt_name`.
    fn make_col(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
            udt_schema: None,
            column_default: None,
        }
    }

    /// MySQL inline-enum column named "status" in schema "test_db".
    fn make_mysql_enum_col(udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: "status".to_string(),
            data_type: "enum".to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "test_db".to_string(),
            udt_schema: None,
            column_default: None,
        }
    }

    /// Generate (as a table, chrono time crate) and format; returns code and imports.
    fn gen_with(
        table: &TableInfo,
        schema: &SchemaInfo,
        db: DatabaseKind,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_struct(
            table,
            db,
            schema,
            derives,
            overrides,
            false,
            TimeCrate::Chrono,
        );
        (parse_and_format(&tokens).unwrap(), imports)
    }

    /// Postgres, default schema, no extra derives, no type overrides.
    fn gen_pg(table: &TableInfo) -> (String, BTreeSet<String>) {
        gen_with(
            table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &[],
            &HashMap::new(),
        )
    }

    /// Formatted code only. (Not named `gen`: that is reserved in Rust 2024.)
    fn gen_code(table: &TableInfo) -> String {
        gen_pg(table).0
    }

    /// Postgres generation with extra derives and default everything else.
    fn gen_with_derives(table: &TableInfo, derives: &[String]) -> (String, BTreeSet<String>) {
        gen_with(
            table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            derives,
            &HashMap::new(),
        )
    }

    /// Override mapping Postgres `jsonb` to a user type `MyJson`.
    fn jsonb_override() -> HashMap<String, String> {
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        overrides
    }

    /// Postgres generation with type overrides; returns code only.
    fn gen_with_overrides(table: &TableInfo, overrides: &HashMap<String, String>) -> String {
        gen_with(
            table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &[],
            overrides,
        )
        .0
    }

    /// MySQL generation with default everything else.
    fn gen_mysql(table: &TableInfo) -> (String, BTreeSet<String>) {
        gen_with(
            table,
            &SchemaInfo::default(),
            DatabaseKind::Mysql,
            &[],
            &HashMap::new(),
        )
    }

    // --- basic structure ---

    #[test]
    fn test_simple_table() {
        let table = make_table(
            "users",
            vec![
                make_col("id", "int4", false),
                make_col("name", "text", false),
            ],
        );
        let code = gen_code(&table);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub name: String"));
    }

    #[test]
    fn test_struct_name_pascal_case_and_singular() {
        let table = make_table("user_roles", vec![make_col("id", "int4", false)]);
        let code = gen_code(&table);
        // Plural table → singular struct, snake_case → PascalCase.
        assert!(
            code.contains("pub struct UserRole"),
            "expected singular PascalCase struct name, got:\n{}",
            code
        );
        assert!(!code.contains("pub struct UserRoles"));
    }

    #[test]
    fn test_struct_name_is_singular() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = gen_code(&table);
        assert!(
            code.contains("pub struct User"),
            "table 'users' must produce singular 'User' struct, got:\n{}",
            code
        );
        assert!(!code.contains("pub struct Users"));
    }

    #[test]
    fn test_struct_name_already_singular_unchanged() {
        let table = make_table("agent_connector", vec![make_col("id", "int4", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub struct AgentConnector"));
    }

    #[test]
    fn test_struct_name_uncountable_unchanged() {
        let table = make_table("news", vec![make_col("id", "int4", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub struct News"));
    }

    #[test]
    fn test_reserved_keyword_column_prefixed_with_singular_table() {
        // table "products", column "type" → "product_type", NOT "products_type".
        let table = make_table(
            "products",
            vec![
                make_col("id", "int4", false),
                make_col("type", "text", false),
            ],
        );
        let code = gen_code(&table);
        assert!(
            code.contains("pub product_type:"),
            "expected singularized prefix 'product_type', got:\n{}",
            code
        );
        assert!(
            !code.contains("pub products_type:"),
            "must not use plural-form prefix, got:\n{}",
            code
        );
        // The actual SQL column name must still be carried via #[sqlx(rename)].
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    #[test]
    fn test_reserved_keyword_column_on_already_singular_table() {
        let table = make_table(
            "connector",
            vec![
                make_col("id", "int4", false),
                make_col("type", "text", false),
            ],
        );
        let code = gen_code(&table);
        assert!(code.contains("pub connector_type:"));
    }

    // --- nullable ---

    #[test]
    fn test_nullable_column() {
        let table = make_table("users", vec![make_col("email", "text", true)]);
        let code = gen_code(&table);
        assert!(code.contains("pub email: Option<String>"));
    }

    #[test]
    fn test_non_nullable_column() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub name: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_mix_nullable() {
        let table = make_table(
            "users",
            vec![make_col("id", "int4", false), make_col("bio", "text", true)],
        );
        let code = gen_code(&table);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub bio: Option<String>"));
    }

    // --- keyword renaming ---

    #[test]
    fn test_keyword_type_renamed() {
        let table = make_table("connector", vec![make_col("type", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub connector_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    #[test]
    fn test_keyword_fn_renamed() {
        let table = make_table("item", vec![make_col("fn", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub item_fn: String"));
        assert!(code.contains("sqlx(rename = \"fn\")"));
    }

    #[test]
    fn test_keyword_match_renamed() {
        let table = make_table("game", vec![make_col("match", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub game_match: String"));
    }

    #[test]
    fn test_non_keyword_no_rename() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = gen_code(&table);
        assert!(!code.contains("sqlx(rename"));
    }

    // --- snake_case renaming ---

    #[test]
    fn test_camel_case_column_renamed() {
        let table = make_table("users", vec![make_col("CreatedAt", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub created_at: String"));
        assert!(code.contains("sqlx(rename = \"CreatedAt\")"));
    }

    #[test]
    fn test_mixed_case_column_renamed() {
        let table = make_table("users", vec![make_col("firstName", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub first_name: String"));
        assert!(code.contains("sqlx(rename = \"firstName\")"));
    }

    #[test]
    fn test_already_snake_case_no_rename() {
        let table = make_table("users", vec![make_col("created_at", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub created_at: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    // --- derives ---

    #[test]
    fn test_default_derives() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = gen_code(&table);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("sqlx::FromRow") || code.contains("sqlx :: FromRow"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let derives = vec!["Serialize".to_string()];
        let (code, _) = gen_with_derives(&table, &derives);
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derives_both_serde() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = gen_with_derives(&table, &derives);
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    // --- imports ---

    #[test]
    fn test_uuid_import() {
        let table = make_table("users", vec![make_col("id", "uuid", false)]);
        let (_, imports) = gen_pg(&table);
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_timestamptz_import() {
        let table = make_table("users", vec![make_col("created_at", "timestamptz", false)]);
        let (_, imports) = gen_pg(&table);
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    #[test]
    fn test_int4_only_serde_import() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let (_, imports) = gen_pg(&table);
        assert_eq!(imports.len(), 2);
        assert!(imports.iter().any(|i| i.contains("serde")));
        assert!(imports.iter().any(|i| i.contains("sqlx_gen::SqlxGen")));
    }

    #[test]
    fn test_multiple_imports_collected() {
        let table = make_table(
            "users",
            vec![
                make_col("id", "uuid", false),
                make_col("created_at", "timestamptz", false),
            ],
        );
        let (_, imports) = gen_pg(&table);
        assert!(imports.iter().any(|i| i.contains("uuid")));
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    // --- MySQL enum ---

    #[test]
    fn test_mysql_enum_column() {
        let table = make_table(
            "users",
            vec![make_mysql_enum_col("enum('active','inactive')", false)],
        );
        let (code, imports) = gen_mysql(&table);
        assert!(code.contains("UsersStatus"));
        assert!(imports.iter().any(|i| i.contains("super::types::")));
    }

    #[test]
    fn test_mysql_enum_nullable() {
        let table = make_table("users", vec![make_mysql_enum_col("enum('a','b')", true)]);
        let (code, _) = gen_mysql(&table);
        assert!(code.contains("Option<UsersStatus>"));
    }

    // --- type overrides ---

    #[test]
    fn test_type_override() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let code = gen_with_overrides(&table, &jsonb_override());
        assert!(code.contains("pub data: MyJson"));
    }

    #[test]
    fn test_type_override_absent() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let code = gen_code(&table);
        assert!(code.contains("Value"));
    }

    #[test]
    fn test_type_override_nullable() {
        let table = make_table("users", vec![make_col("data", "jsonb", true)]);
        let code = gen_with_overrides(&table, &jsonb_override());
        assert!(code.contains("Option<MyJson>"));
    }

    // --- native array types ---

    #[test]
    fn test_native_array_text_gets_sql_type_annotation() {
        let table = make_table("posts", vec![make_col("tags", "_text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("Vec<String>"));
        assert!(code.contains("sql_type = \"text\""));
        assert!(code.contains("is_array"));
    }

    #[test]
    fn test_native_array_int4_gets_sql_type_annotation() {
        let table = make_table("data", vec![make_col("values", "_int4", false)]);
        let code = gen_code(&table);
        assert!(code.contains("Vec<i32>"));
        assert!(code.contains("sql_type = \"int4\""));
        assert!(code.contains("is_array"));
    }

    #[test]
    fn test_native_array_nullable_gets_sql_type_annotation() {
        let table = make_table("posts", vec![make_col("tags", "_text", true)]);
        let code = gen_code(&table);
        assert!(code.contains("Option<Vec<String>>"));
        assert!(code.contains("sql_type = \"text\""));
        assert!(code.contains("is_array"));
    }

    #[test]
    fn test_scalar_builtin_no_sql_type_annotation() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub name: String"));
        assert!(!code.contains("sql_type"));
    }

    // ========== sanitize_rust_ident ==========

    #[test]
    fn test_sanitize_replaces_dash() {
        assert_eq!(sanitize_rust_ident("user-id"), "user_id");
    }

    #[test]
    fn test_sanitize_replaces_space() {
        assert_eq!(sanitize_rust_ident("created at"), "created_at");
    }

    #[test]
    fn test_sanitize_replaces_dot() {
        assert_eq!(sanitize_rust_ident("a.b"), "a_b");
    }

    #[test]
    fn test_sanitize_prefixes_leading_digit() {
        assert_eq!(sanitize_rust_ident("123abc"), "_123abc");
    }

    #[test]
    fn test_sanitize_empty_becomes_placeholder() {
        assert_eq!(sanitize_rust_ident(""), "_field");
    }

    #[test]
    fn test_sanitize_leaves_valid_ident_unchanged() {
        assert_eq!(sanitize_rust_ident("user_id"), "user_id");
        assert_eq!(sanitize_rust_ident("_private"), "_private");
    }

    #[test]
    fn test_column_with_dash_generates_valid_rust() {
        let table = make_table("users", vec![make_col("user-id", "int4", false)]);
        let code = gen_code(&table);
        // Must produce a Rust-legal identifier; renamed back to the original via #[sqlx(rename)].
        assert!(
            code.contains("pub user_id:"),
            "expected sanitized identifier, got:\n{}",
            code
        );
        assert!(code.contains("sqlx(rename = \"user-id\")"));
    }
}