Skip to main content

sqlx_gen/codegen/
struct_gen.rs

1use std::collections::{BTreeSet, HashMap};
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::{DatabaseKind, TimeCrate};
8use crate::codegen::naming::singularize;
9use crate::codegen::{imports_for_derives, is_rust_keyword};
10use crate::introspect::{SchemaInfo, TableInfo};
11use crate::typemap;
12
13pub fn generate_struct(
14    table: &TableInfo,
15    db_kind: DatabaseKind,
16    schema_info: &SchemaInfo,
17    extra_derives: &[String],
18    type_overrides: &HashMap<String, String>,
19    is_view: bool,
20    time_crate: TimeCrate,
21) -> (TokenStream, BTreeSet<String>) {
22    let mut imports = BTreeSet::new();
23    for imp in imports_for_derives(extra_derives) {
24        imports.insert(imp);
25    }
26    // Tables are conventionally plural ("users"), structs singular ("User"),
27    // matching the Rust ORM ecosystem (Diesel, SeaORM) and ActiveRecord.
28    let struct_name = format_ident!("{}", singularize(&table.name).to_upper_camel_case());
29
30    // Build derive list
31    imports.insert("use serde::{Serialize, Deserialize};".to_string());
32    imports.insert("use sqlx_gen::SqlxGen;".to_string());
33    let mut derive_tokens = vec![
34        quote! { Debug },
35        quote! { Clone },
36        quote! { PartialEq },
37        quote! { Eq },
38        quote! { Serialize },
39        quote! { Deserialize },
40        quote! { sqlx::FromRow },
41        quote! { SqlxGen },
42    ];
43    for d in extra_derives {
44        let ident = format_ident!("{}", d);
45        derive_tokens.push(quote! { #ident });
46    }
47
48    // Build fields
49    let fields: Vec<TokenStream> = table
50        .columns
51        .iter()
52        .map(|col| {
53            let rust_type =
54                resolve_column_type(col, db_kind, table, schema_info, type_overrides, time_crate);
55            if let Some(imp) = &rust_type.needs_import {
56                imports.insert(imp.clone());
57            }
58
59            let field_name_snake = sanitize_rust_ident(&col.name.to_snake_case());
60            // If the field name is a Rust keyword, prefix with table name
61            // e.g. column "type" on table "connector" → "connector_type"
62            let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
63                let prefixed = format!("{}_{}", table.name.to_snake_case(), field_name_snake);
64                (prefixed, true)
65            } else {
66                let changed = field_name_snake != col.name;
67                (field_name_snake, changed)
68            };
69
70            let field_ident = format_ident!("{}", effective_name);
71            let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
72                let fallback = format_ident!("String");
73                quote! { #fallback }
74            });
75
76            let rename = if needs_rename {
77                let original = &col.name;
78                quote! { #[sqlx(rename = #original)] }
79            } else {
80                quote! {}
81            };
82
83            // Build #[sqlx_gen(...)] attribute with optional primary_key, sql_type, is_array, column_default
84            let (sql_type, is_sql_array) = detect_custom_sql_type(&col.udt_name, schema_info);
85            let has_pk = col.is_primary_key;
86            let has_sql_type = sql_type.is_some();
87            let has_default = col.column_default.is_some();
88
89            let sqlx_gen_attr = if has_pk || has_sql_type || has_default {
90                let pk_part = if has_pk {
91                    quote! { primary_key, }
92                } else {
93                    quote! {}
94                };
95                let sql_type_part = match &sql_type {
96                    Some(t) => quote! { sql_type = #t, },
97                    None => quote! {},
98                };
99                let array_part = if is_sql_array {
100                    quote! { is_array, }
101                } else {
102                    quote! {}
103                };
104                let default_part = match &col.column_default {
105                    Some(d) => quote! { column_default = #d, },
106                    None => quote! {},
107                };
108                quote! { #[sqlx_gen(#pk_part #sql_type_part #array_part #default_part)] }
109            } else {
110                quote! {}
111            };
112
113            quote! {
114                #sqlx_gen_attr
115                #rename
116                pub #field_ident: #type_tokens,
117            }
118        })
119        .collect();
120
121    let table_name_str = &table.name;
122    let schema_name_str = &table.schema_name;
123    let kind_str = if is_view { "view" } else { "table" };
124
125    let tokens = quote! {
126        #[derive(#(#derive_tokens),*)]
127        #[sqlx_gen(kind = #kind_str, schema = #schema_name_str, table = #table_name_str)]
128        pub struct #struct_name {
129            #(#fields)*
130        }
131    };
132
133    (tokens, imports)
134}
135
/// Sanitize a candidate Rust identifier.
///
/// - Every character that is not ASCII alphanumeric or `_` becomes `_`.
/// - A leading ASCII digit gets a `_` prefix (`123abc` → `_123abc`).
/// - The input falls back to `"_field"` when it is empty, or when it sanitizes
///   to the lone wildcard `_` (e.g. `é`), since `_` is not usable as a field name.
///
/// This lets sqlx-gen survive columns named `user-id`, `created at`, `123`, etc.
/// They still need a `#[sqlx(rename = "<original>")]` to roundtrip the DB
/// column, which the caller handles via the `changed` flag.
pub(crate) fn sanitize_rust_ident(name: &str) -> String {
    let mut out: String = name
        .chars()
        .map(|c| if c.is_ascii_alphanumeric() || c == '_' { c } else { '_' })
        .collect();
    if out.is_empty() || out == "_" {
        return "_field".to_string();
    }
    if out.starts_with(|c: char| c.is_ascii_digit()) {
        out.insert(0, '_');
    }
    out
}
163
164/// Detect if a column uses a custom SQL type (enum or composite) and return the qualified
165/// SQL type name for casting, plus whether it's an array.
166/// Returns `(Some("type_name"), true)` for arrays of custom types,
167/// `(Some("type_name"), false)` for scalar custom types, or `(None, false)` for built-in types.
168fn detect_custom_sql_type(udt_name: &str, schema_info: &SchemaInfo) -> (Option<String>, bool) {
169    let (base_name, is_array) = match udt_name.strip_prefix('_') {
170        Some(inner) => (inner, true),
171        None => (udt_name, false),
172    };
173
174    // Check enums
175    if schema_info.enums.iter().any(|e| e.name == base_name) {
176        return (Some(base_name.to_string()), is_array);
177    }
178
179    // Check composite types
180    if schema_info
181        .composite_types
182        .iter()
183        .any(|c| c.name == base_name)
184    {
185        return (Some(base_name.to_string()), is_array);
186    }
187
188    // Check if this is a non-builtin type that would hit the typemap fallback
189    // (e.g. range types like "timerange", "tsrange", etc.)
190    // Domains resolve to their base type, so they don't need marking.
191    let is_domain = schema_info.domains.iter().any(|d| d.name == base_name);
192    if !is_domain && !typemap::postgres::is_builtin(base_name) {
193        return (Some(base_name.to_string()), is_array);
194    }
195
196    // Native arrays of builtin types (e.g. `_text` → `text[]`) need sql_type annotation
197    // so that codegen falls back to runtime mode — `query_as!` macro doesn't support
198    // `Vec<T>` without `PgHasArrayType`.
199    if is_array {
200        return (Some(base_name.to_string()), true);
201    }
202
203    (None, false)
204}
205
206fn resolve_column_type(
207    col: &crate::introspect::ColumnInfo,
208    db_kind: DatabaseKind,
209    table: &TableInfo,
210    schema_info: &SchemaInfo,
211    type_overrides: &HashMap<String, String>,
212    time_crate: TimeCrate,
213) -> typemap::RustType {
214    // For MySQL ENUM columns, resolve to the generated enum type
215    if db_kind == DatabaseKind::Mysql && col.udt_name.starts_with("enum(") {
216        let enum_type_name = typemap::mysql::resolve_enum_type(&table.name, &col.name);
217        let rt = typemap::RustType::with_import(
218            &enum_type_name,
219            &format!("use super::types::{};", enum_type_name),
220        );
221        return if col.is_nullable {
222            rt.wrap_option()
223        } else {
224            rt
225        };
226    }
227
228    typemap::map_column(col, db_kind, schema_info, type_overrides, time_crate)
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    // ---------- fixtures ----------

    /// A table in the `public` schema with the given columns.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// A plain `public` column: not a primary key, no default, `data_type == udt_name`.
    fn make_col(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
            udt_schema: None,
            column_default: None,
        }
    }

    /// A MySQL inline `enum(...)` column as introspected from schema `test_db`.
    fn make_mysql_enum_col(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            data_type: "enum".to_string(),
            schema_name: "test_db".to_string(),
            ..make_col(name, udt_name, nullable)
        }
    }

    /// Run `generate_struct` for a *table* and return the pretty-printed code
    /// plus the collected imports.
    fn render_with(
        table: &TableInfo,
        schema: &SchemaInfo,
        db: DatabaseKind,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_struct(
            table,
            db,
            schema,
            derives,
            overrides,
            false,
            TimeCrate::Chrono,
        );
        (parse_and_format(&tokens).unwrap(), imports)
    }

    /// Pretty-printed Postgres table struct with no extra derives or overrides.
    fn render(table: &TableInfo) -> String {
        render_with(table, &SchemaInfo::default(), DatabaseKind::Postgres, &[], &HashMap::new()).0
    }

    /// Imports collected for a Postgres table with no extra derives or overrides.
    fn imports_of(table: &TableInfo) -> BTreeSet<String> {
        render_with(table, &SchemaInfo::default(), DatabaseKind::Postgres, &[], &HashMap::new()).1
    }

    /// Pretty-printed Postgres *view* struct with no extra derives or overrides.
    fn render_view(table: &TableInfo) -> String {
        let schema = SchemaInfo::default();
        let (tokens, _) = generate_struct(
            table,
            DatabaseKind::Postgres,
            &schema,
            &[],
            &HashMap::new(),
            true,
            TimeCrate::Chrono,
        );
        parse_and_format(&tokens).unwrap()
    }

    // --- basic structure ---

    #[test]
    fn test_simple_table() {
        let table = make_table(
            "users",
            vec![
                make_col("id", "int4", false),
                make_col("name", "text", false),
            ],
        );
        let code = render(&table);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub name: String"));
    }

    #[test]
    fn test_struct_name_pascal_case_and_singular() {
        let table = make_table("user_roles", vec![make_col("id", "int4", false)]);
        let code = render(&table);
        // Plural table → singular struct, snake_case → PascalCase.
        assert!(
            code.contains("pub struct UserRole"),
            "expected singular PascalCase struct name, got:\n{}",
            code
        );
        assert!(!code.contains("pub struct UserRoles"));
    }

    #[test]
    fn test_struct_name_is_singular() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = render(&table);
        assert!(
            code.contains("pub struct User"),
            "table 'users' must produce singular 'User' struct, got:\n{}",
            code
        );
        assert!(!code.contains("pub struct Users"));
    }

    #[test]
    fn test_struct_name_already_singular_unchanged() {
        let table = make_table("agent_connector", vec![make_col("id", "int4", false)]);
        assert!(render(&table).contains("pub struct AgentConnector"));
    }

    #[test]
    fn test_struct_name_uncountable_unchanged() {
        let table = make_table("news", vec![make_col("id", "int4", false)]);
        assert!(render(&table).contains("pub struct News"));
    }

    // --- struct-level #[sqlx_gen] attribute ---

    #[test]
    fn test_table_level_attribute() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = render(&table);
        assert!(code.contains("kind = \"table\""), "got:\n{}", code);
        assert!(code.contains("schema = \"public\""), "got:\n{}", code);
        assert!(code.contains("table = \"users\""), "got:\n{}", code);
    }

    #[test]
    fn test_view_kind_attribute() {
        let table = make_table("active_users", vec![make_col("id", "int4", false)]);
        let code = render_view(&table);
        assert!(code.contains("kind = \"view\""), "got:\n{}", code);
        assert!(!code.contains("kind = \"table\""));
    }

    // --- field-level #[sqlx_gen] attribute ---

    #[test]
    fn test_primary_key_attribute() {
        let mut id = make_col("id", "int4", false);
        id.is_primary_key = true;
        let table = make_table("users", vec![id, make_col("name", "text", false)]);
        let code = render(&table);
        // Only the `id` column is marked.
        assert_eq!(code.matches("primary_key").count(), 1, "got:\n{}", code);
    }

    #[test]
    fn test_column_default_attribute() {
        let mut created = make_col("created_at", "timestamptz", false);
        created.column_default = Some("now()".to_string());
        let code = render(&make_table("users", vec![created]));
        assert!(code.contains("column_default = \"now()\""), "got:\n{}", code);
    }

    #[test]
    fn test_plain_column_has_no_field_attribute() {
        let code = render(&make_table("users", vec![make_col("name", "text", false)]));
        // Only the struct-level `#[sqlx_gen(kind = ...)]` attribute is present.
        assert_eq!(code.matches("sqlx_gen(").count(), 1, "got:\n{}", code);
    }

    // --- nullable ---

    #[test]
    fn test_nullable_column() {
        let table = make_table("users", vec![make_col("email", "text", true)]);
        assert!(render(&table).contains("pub email: Option<String>"));
    }

    #[test]
    fn test_non_nullable_column() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = render(&table);
        assert!(code.contains("pub name: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_mix_nullable() {
        let table = make_table(
            "users",
            vec![make_col("id", "int4", false), make_col("bio", "text", true)],
        );
        let code = render(&table);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub bio: Option<String>"));
    }

    // --- keyword renaming ---

    #[test]
    fn test_keyword_type_renamed() {
        let table = make_table("connector", vec![make_col("type", "text", false)]);
        let code = render(&table);
        assert!(code.contains("pub connector_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    #[test]
    fn test_keyword_fn_renamed() {
        let table = make_table("item", vec![make_col("fn", "text", false)]);
        let code = render(&table);
        assert!(code.contains("pub item_fn: String"));
        assert!(code.contains("sqlx(rename = \"fn\")"));
    }

    #[test]
    fn test_keyword_match_renamed() {
        let table = make_table("game", vec![make_col("match", "text", false)]);
        assert!(render(&table).contains("pub game_match: String"));
    }

    #[test]
    fn test_non_keyword_no_rename() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        assert!(!render(&table).contains("sqlx(rename"));
    }

    // --- snake_case renaming ---

    #[test]
    fn test_camel_case_column_renamed() {
        let table = make_table("users", vec![make_col("CreatedAt", "text", false)]);
        let code = render(&table);
        assert!(code.contains("pub created_at: String"));
        assert!(code.contains("sqlx(rename = \"CreatedAt\")"));
    }

    #[test]
    fn test_mixed_case_column_renamed() {
        let table = make_table("users", vec![make_col("firstName", "text", false)]);
        let code = render(&table);
        assert!(code.contains("pub first_name: String"));
        assert!(code.contains("sqlx(rename = \"firstName\")"));
    }

    #[test]
    fn test_already_snake_case_no_rename() {
        let table = make_table("users", vec![make_col("created_at", "text", false)]);
        let code = render(&table);
        assert!(code.contains("pub created_at: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    // --- derives ---

    #[test]
    fn test_default_derives() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = render(&table);
        for derive in ["Debug", "Clone", "PartialEq", "Eq", "Serialize", "Deserialize", "SqlxGen"] {
            assert!(code.contains(derive), "missing derive {}:\n{}", derive, code);
        }
        assert!(code.contains("sqlx::FromRow") || code.contains("sqlx :: FromRow"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let derives = vec!["Serialize".to_string()];
        let (code, _) = render_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &derives,
            &HashMap::new(),
        );
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derives_both_serde() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = render_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &derives,
            &HashMap::new(),
        );
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    // --- imports ---

    #[test]
    fn test_uuid_import() {
        let table = make_table("users", vec![make_col("id", "uuid", false)]);
        assert!(imports_of(&table).iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_timestamptz_import() {
        let table = make_table("users", vec![make_col("created_at", "timestamptz", false)]);
        assert!(imports_of(&table).iter().any(|i| i.contains("chrono")));
    }

    #[test]
    fn test_int4_only_serde_import() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let imports = imports_of(&table);
        assert_eq!(imports.len(), 2);
        assert!(imports.iter().any(|i| i.contains("serde")));
        assert!(imports.iter().any(|i| i.contains("sqlx_gen::SqlxGen")));
    }

    #[test]
    fn test_multiple_imports_collected() {
        let table = make_table(
            "users",
            vec![
                make_col("id", "uuid", false),
                make_col("created_at", "timestamptz", false),
            ],
        );
        let imports = imports_of(&table);
        assert!(imports.iter().any(|i| i.contains("uuid")));
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    // --- MySQL enum ---

    #[test]
    fn test_mysql_enum_column() {
        let table = make_table(
            "users",
            vec![make_mysql_enum_col("status", "enum('active','inactive')", false)],
        );
        let (code, imports) = render_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Mysql,
            &[],
            &HashMap::new(),
        );
        assert!(code.contains("UsersStatus"));
        assert!(imports.iter().any(|i| i.contains("super::types::")));
    }

    #[test]
    fn test_mysql_enum_nullable() {
        let table = make_table(
            "users",
            vec![make_mysql_enum_col("status", "enum('a','b')", true)],
        );
        let (code, _) = render_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Mysql,
            &[],
            &HashMap::new(),
        );
        assert!(code.contains("Option<UsersStatus>"));
    }

    // --- type overrides ---

    #[test]
    fn test_type_override() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = render_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &[],
            &overrides,
        );
        assert!(code.contains("pub data: MyJson"));
    }

    #[test]
    fn test_type_override_absent() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        assert!(render(&table).contains("Value"));
    }

    #[test]
    fn test_type_override_nullable() {
        let table = make_table("users", vec![make_col("data", "jsonb", true)]);
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = render_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &[],
            &overrides,
        );
        assert!(code.contains("Option<MyJson>"));
    }

    // --- native array types ---

    #[test]
    fn test_native_array_text_gets_sql_type_annotation() {
        let table = make_table("posts", vec![make_col("tags", "_text", false)]);
        let code = render(&table);
        assert!(code.contains("Vec<String>"));
        assert!(code.contains("sql_type = \"text\""));
        assert!(code.contains("is_array"));
    }

    #[test]
    fn test_native_array_int4_gets_sql_type_annotation() {
        let table = make_table("data", vec![make_col("values", "_int4", false)]);
        let code = render(&table);
        assert!(code.contains("Vec<i32>"));
        assert!(code.contains("sql_type = \"int4\""));
        assert!(code.contains("is_array"));
    }

    #[test]
    fn test_native_array_nullable_gets_sql_type_annotation() {
        let table = make_table("posts", vec![make_col("tags", "_text", true)]);
        let code = render(&table);
        assert!(code.contains("Option<Vec<String>>"));
        assert!(code.contains("sql_type = \"text\""));
        assert!(code.contains("is_array"));
    }

    #[test]
    fn test_scalar_builtin_no_sql_type_annotation() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = render(&table);
        assert!(code.contains("pub name: String"));
        assert!(!code.contains("sql_type"));
    }

    // ========== sanitize_rust_ident ==========

    #[test]
    fn test_sanitize_replaces_dash() {
        assert_eq!(sanitize_rust_ident("user-id"), "user_id");
    }

    #[test]
    fn test_sanitize_replaces_space() {
        assert_eq!(sanitize_rust_ident("created at"), "created_at");
    }

    #[test]
    fn test_sanitize_replaces_dot() {
        assert_eq!(sanitize_rust_ident("a.b"), "a_b");
    }

    #[test]
    fn test_sanitize_prefixes_leading_digit() {
        assert_eq!(sanitize_rust_ident("123abc"), "_123abc");
    }

    #[test]
    fn test_sanitize_empty_becomes_placeholder() {
        assert_eq!(sanitize_rust_ident(""), "_field");
    }

    #[test]
    fn test_sanitize_leaves_valid_ident_unchanged() {
        assert_eq!(sanitize_rust_ident("user_id"), "user_id");
        assert_eq!(sanitize_rust_ident("_private"), "_private");
    }

    #[test]
    fn test_column_with_dash_generates_valid_rust() {
        let table = make_table("users", vec![make_col("user-id", "int4", false)]);
        let code = render(&table);
        // Must produce a Rust-legal identifier, renamed back to the original via #[sqlx(rename)].
        assert!(
            code.contains("pub user_id: i32"),
            "expected sanitized identifier, got:\n{}",
            code
        );
        assert!(code.contains("sqlx(rename = \"user-id\")"));
    }

    #[test]
    fn test_column_with_space_generates_valid_rust() {
        let table = make_table("users", vec![make_col("created at", "text", false)]);
        let code = render(&table);
        assert!(code.contains("pub created_at: String"), "got:\n{}", code);
        assert!(code.contains("sqlx(rename = \"created at\")"));
    }
}