Skip to main content

sqlx_gen/codegen/
struct_gen.rs

1use std::collections::{BTreeSet, HashMap};
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::{imports_for_derives, is_rust_keyword};
9use crate::introspect::{SchemaInfo, TableInfo};
10use crate::typemap;
11
12pub fn generate_struct(
13    table: &TableInfo,
14    db_kind: DatabaseKind,
15    schema_info: &SchemaInfo,
16    extra_derives: &[String],
17    type_overrides: &HashMap<String, String>,
18    is_view: bool,
19) -> (TokenStream, BTreeSet<String>) {
20    let mut imports = BTreeSet::new();
21    for imp in imports_for_derives(extra_derives) {
22        imports.insert(imp);
23    }
24    let struct_name = format_ident!("{}", table.name.to_upper_camel_case());
25
26    // Build derive list
27    imports.insert("use serde::{Serialize, Deserialize};".to_string());
28    imports.insert("use sqlx_gen::SqlxGen;".to_string());
29    let mut derive_tokens = vec![
30        quote! { Debug },
31        quote! { Clone },
32        quote! { PartialEq },
33        quote! { Eq },
34        quote! { Serialize },
35        quote! { Deserialize },
36        quote! { sqlx::FromRow },
37        quote! { SqlxGen },
38    ];
39    for d in extra_derives {
40        let ident = format_ident!("{}", d);
41        derive_tokens.push(quote! { #ident });
42    }
43
44    // Build fields
45    let fields: Vec<TokenStream> = table
46        .columns
47        .iter()
48        .map(|col| {
49            let rust_type = resolve_column_type(col, db_kind, table, schema_info, type_overrides);
50            if let Some(imp) = &rust_type.needs_import {
51                imports.insert(imp.clone());
52            }
53
54            let field_name_snake = col.name.to_snake_case();
55            // If the field name is a Rust keyword, prefix with table name
56            // e.g. column "type" on table "connector" → "connector_type"
57            let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
58                let prefixed = format!(
59                    "{}_{}",
60                    table.name.to_snake_case(),
61                    field_name_snake
62                );
63                (prefixed, true)
64            } else {
65                let changed = field_name_snake != col.name;
66                (field_name_snake, changed)
67            };
68
69            let field_ident = format_ident!("{}", effective_name);
70            let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
71                let fallback = format_ident!("String");
72                quote! { #fallback }
73            });
74
75            let rename = if needs_rename {
76                let original = &col.name;
77                quote! { #[sqlx(rename = #original)] }
78            } else {
79                quote! {}
80            };
81
82            // Build #[sqlx_gen(...)] attribute with optional primary_key, sql_type, is_array
83            let (sql_type, is_sql_array) = detect_custom_sql_type(&col.udt_name, schema_info);
84            let has_pk = col.is_primary_key;
85            let has_sql_type = sql_type.is_some();
86
87            let sqlx_gen_attr = if has_pk || has_sql_type {
88                let pk_part = if has_pk { quote! { primary_key, } } else { quote! {} };
89                let sql_type_part = match &sql_type {
90                    Some(t) => quote! { sql_type = #t, },
91                    None => quote! {},
92                };
93                let array_part = if is_sql_array { quote! { is_array, } } else { quote! {} };
94                quote! { #[sqlx_gen(#pk_part #sql_type_part #array_part)] }
95            } else {
96                quote! {}
97            };
98
99            quote! {
100                #sqlx_gen_attr
101                #rename
102                pub #field_ident: #type_tokens,
103            }
104        })
105        .collect();
106
107    let table_name_str = &table.name;
108    let schema_name_str = &table.schema_name;
109    let kind_str = if is_view { "view" } else { "table" };
110
111    let tokens = quote! {
112        #[derive(#(#derive_tokens),*)]
113        #[sqlx_gen(kind = #kind_str, schema = #schema_name_str, table = #table_name_str)]
114        pub struct #struct_name {
115            #(#fields)*
116        }
117    };
118
119    (tokens, imports)
120}
121
122/// Detect if a column uses a custom SQL type (enum or composite) and return the qualified
123/// SQL type name for casting, plus whether it's an array.
124/// Returns `(Some("schema.type_name"), true)` for arrays of custom types,
125/// `(Some("type_name"), false)` for scalar custom types, or `(None, false)` for built-in types.
126fn detect_custom_sql_type(udt_name: &str, schema_info: &SchemaInfo) -> (Option<String>, bool) {
127    let (base_name, is_array) = match udt_name.strip_prefix('_') {
128        Some(inner) => (inner, true),
129        None => (udt_name, false),
130    };
131
132    // Check enums
133    if let Some(e) = schema_info.enums.iter().find(|e| e.name == base_name) {
134        let qualified = if e.schema_name == "public" {
135            base_name.to_string()
136        } else {
137            format!("{}.{}", e.schema_name, base_name)
138        };
139        return (Some(qualified), is_array);
140    }
141
142    // Check composite types
143    if let Some(c) = schema_info.composite_types.iter().find(|c| c.name == base_name) {
144        let qualified = if c.schema_name == "public" {
145            base_name.to_string()
146        } else {
147            format!("{}.{}", c.schema_name, base_name)
148        };
149        return (Some(qualified), is_array);
150    }
151
152    // Check if this is a non-builtin type that would hit the typemap fallback
153    // (e.g. range types like "timerange", "tsrange", etc.)
154    // Domains resolve to their base type, so they don't need marking.
155    let is_domain = schema_info.domains.iter().any(|d| d.name == base_name);
156    if !is_domain && !typemap::postgres::is_builtin(base_name) {
157        return (Some(base_name.to_string()), is_array);
158    }
159
160    (None, false)
161}
162
163fn resolve_column_type(
164    col: &crate::introspect::ColumnInfo,
165    db_kind: DatabaseKind,
166    table: &TableInfo,
167    schema_info: &SchemaInfo,
168    type_overrides: &HashMap<String, String>,
169) -> typemap::RustType {
170    // For MySQL ENUM columns, resolve to the generated enum type
171    if db_kind == DatabaseKind::Mysql && col.udt_name.starts_with("enum(") {
172        let enum_type_name = typemap::mysql::resolve_enum_type(&table.name, &col.name);
173        let rt = typemap::RustType::with_import(
174            &enum_type_name,
175            &format!("use super::types::{};", enum_type_name),
176        );
177        return if col.is_nullable {
178            rt.wrap_option()
179        } else {
180            rt
181        };
182    }
183
184    typemap::map_column(col, db_kind, schema_info, type_overrides)
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    /// Build a table in the `public` schema with the given columns.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: String::from("public"),
            name: String::from(name),
            columns,
        }
    }

    /// Build a non-primary-key column in the `public` schema whose
    /// `data_type` mirrors `udt_name`.
    fn make_col(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: String::from(name),
            data_type: String::from(udt_name),
            udt_name: String::from(udt_name),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: String::from("public"),
            column_default: None,
        }
    }

    /// Build a table holding exactly one column.
    fn single_col_table(table: &str, col: &str, udt_name: &str, nullable: bool) -> TableInfo {
        make_table(table, vec![make_col(col, udt_name, nullable)])
    }

    /// Build a MySQL `status` ENUM column living in the `test_db` schema.
    fn mysql_enum_col(udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            data_type: String::from("enum"),
            schema_name: String::from("test_db"),
            ..make_col("status", udt_name, nullable)
        }
    }

    /// Render a table for Postgres with an empty schema, no extra derives and
    /// no type overrides.
    fn render(table: &TableInfo) -> String {
        let no_overrides = HashMap::new();
        let (tokens, _imports) = generate_struct(
            table,
            DatabaseKind::Postgres,
            &SchemaInfo::default(),
            &[],
            &no_overrides,
            false,
        );
        parse_and_format(&tokens)
    }

    /// Render a table with full control over the inputs; returns the
    /// formatted code together with the collected imports.
    fn render_with(
        table: &TableInfo,
        schema: &SchemaInfo,
        db: DatabaseKind,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_struct(table, db, schema, derives, overrides, false);
        (parse_and_format(&tokens), imports)
    }

    /// Imports collected for a Postgres table with default settings.
    fn pg_imports(table: &TableInfo) -> BTreeSet<String> {
        let (_code, imports) = render_with(
            table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &[],
            &HashMap::new(),
        );
        imports
    }

    /// True when any collected import line contains `needle`.
    fn has_import(imports: &BTreeSet<String>, needle: &str) -> bool {
        imports.iter().any(|line| line.contains(needle))
    }

    // --- basic structure ---

    #[test]
    fn test_simple_table() {
        let columns = vec![
            make_col("id", "int4", false),
            make_col("name", "text", false),
        ];
        let code = render(&make_table("users", columns));
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub name: String"));
    }

    #[test]
    fn test_struct_name_pascal_case() {
        let code = render(&single_col_table("user_roles", "id", "int4", false));
        assert!(code.contains("pub struct UserRoles"));
    }

    #[test]
    fn test_struct_name_simple() {
        let code = render(&single_col_table("users", "id", "int4", false));
        assert!(code.contains("pub struct Users"));
    }

    // --- nullable ---

    #[test]
    fn test_nullable_column() {
        let code = render(&single_col_table("users", "email", "text", true));
        assert!(code.contains("pub email: Option<String>"));
    }

    #[test]
    fn test_non_nullable_column() {
        let code = render(&single_col_table("users", "name", "text", false));
        assert!(code.contains("pub name: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_mix_nullable() {
        let columns = vec![
            make_col("id", "int4", false),
            make_col("bio", "text", true),
        ];
        let code = render(&make_table("users", columns));
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub bio: Option<String>"));
    }

    // --- keyword renaming ---

    #[test]
    fn test_keyword_type_renamed() {
        let code = render(&single_col_table("connector", "type", "text", false));
        assert!(code.contains("pub connector_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    #[test]
    fn test_keyword_fn_renamed() {
        let code = render(&single_col_table("item", "fn", "text", false));
        assert!(code.contains("pub item_fn: String"));
        assert!(code.contains("sqlx(rename = \"fn\")"));
    }

    #[test]
    fn test_keyword_match_renamed() {
        let code = render(&single_col_table("game", "match", "text", false));
        assert!(code.contains("pub game_match: String"));
    }

    #[test]
    fn test_non_keyword_no_rename() {
        let code = render(&single_col_table("users", "name", "text", false));
        assert!(!code.contains("sqlx(rename"));
    }

    // --- snake_case renaming ---

    #[test]
    fn test_camel_case_column_renamed() {
        let code = render(&single_col_table("users", "CreatedAt", "text", false));
        assert!(code.contains("pub created_at: String"));
        assert!(code.contains("sqlx(rename = \"CreatedAt\")"));
    }

    #[test]
    fn test_mixed_case_column_renamed() {
        let code = render(&single_col_table("users", "firstName", "text", false));
        assert!(code.contains("pub first_name: String"));
        assert!(code.contains("sqlx(rename = \"firstName\")"));
    }

    #[test]
    fn test_already_snake_case_no_rename() {
        let code = render(&single_col_table("users", "created_at", "text", false));
        assert!(code.contains("pub created_at: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    // --- derives ---

    #[test]
    fn test_default_derives() {
        let code = render(&single_col_table("users", "id", "int4", false));
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        // The formatter may or may not put spaces around `::`.
        assert!(code.contains("sqlx::FromRow") || code.contains("sqlx :: FromRow"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let table = single_col_table("users", "id", "int4", false);
        let derives = vec![String::from("Serialize")];
        let (code, _imports) = render_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &derives,
            &HashMap::new(),
        );
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derives_both_serde() {
        let table = single_col_table("users", "id", "int4", false);
        let derives = vec![String::from("Serialize"), String::from("Deserialize")];
        let (_code, imports) = render_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &derives,
            &HashMap::new(),
        );
        assert!(has_import(&imports, "serde"));
    }

    // --- imports ---

    #[test]
    fn test_uuid_import() {
        let imports = pg_imports(&single_col_table("users", "id", "uuid", false));
        assert!(has_import(&imports, "uuid::Uuid"));
    }

    #[test]
    fn test_timestamptz_import() {
        let imports = pg_imports(&single_col_table("users", "created_at", "timestamptz", false));
        assert!(has_import(&imports, "chrono"));
    }

    #[test]
    fn test_int4_only_serde_import() {
        let imports = pg_imports(&single_col_table("users", "id", "int4", false));
        assert_eq!(imports.len(), 2);
        assert!(has_import(&imports, "serde"));
        assert!(has_import(&imports, "sqlx_gen::SqlxGen"));
    }

    #[test]
    fn test_multiple_imports_collected() {
        let columns = vec![
            make_col("id", "uuid", false),
            make_col("created_at", "timestamptz", false),
        ];
        let imports = pg_imports(&make_table("users", columns));
        assert!(has_import(&imports, "uuid"));
        assert!(has_import(&imports, "chrono"));
    }

    // --- MySQL enum ---

    #[test]
    fn test_mysql_enum_column() {
        let table = make_table(
            "users",
            vec![mysql_enum_col("enum('active','inactive')", false)],
        );
        let (code, imports) = render_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Mysql,
            &[],
            &HashMap::new(),
        );
        assert!(code.contains("UsersStatus"));
        assert!(has_import(&imports, "super::types::"));
    }

    #[test]
    fn test_mysql_enum_nullable() {
        let table = make_table("users", vec![mysql_enum_col("enum('a','b')", true)]);
        let (code, _imports) = render_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Mysql,
            &[],
            &HashMap::new(),
        );
        assert!(code.contains("Option<UsersStatus>"));
    }

    // --- type overrides ---

    /// Overrides mapping `jsonb` to `MyJson`.
    fn jsonb_override() -> HashMap<String, String> {
        let mut overrides = HashMap::new();
        overrides.insert(String::from("jsonb"), String::from("MyJson"));
        overrides
    }

    #[test]
    fn test_type_override() {
        let table = single_col_table("users", "data", "jsonb", false);
        let (code, _imports) = render_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &[],
            &jsonb_override(),
        );
        assert!(code.contains("pub data: MyJson"));
    }

    #[test]
    fn test_type_override_absent() {
        let table = single_col_table("users", "data", "jsonb", false);
        let (code, _imports) = render_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &[],
            &HashMap::new(),
        );
        assert!(code.contains("Value"));
    }

    #[test]
    fn test_type_override_nullable() {
        let table = single_col_table("users", "data", "jsonb", true);
        let (code, _imports) = render_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &[],
            &jsonb_override(),
        );
        assert!(code.contains("Option<MyJson>"));
    }
}