Skip to main content

sqlx_gen/codegen/
struct_gen.rs

1use std::collections::{BTreeSet, HashMap};
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::{imports_for_derives, is_rust_keyword};
9use crate::introspect::{SchemaInfo, TableInfo};
10use crate::typemap;
11
12pub fn generate_struct(
13    table: &TableInfo,
14    db_kind: DatabaseKind,
15    schema_info: &SchemaInfo,
16    extra_derives: &[String],
17    type_overrides: &HashMap<String, String>,
18    is_view: bool,
19) -> (TokenStream, BTreeSet<String>) {
20    let mut imports = BTreeSet::new();
21    for imp in imports_for_derives(extra_derives) {
22        imports.insert(imp);
23    }
24    let struct_name = format_ident!("{}", table.name.to_upper_camel_case());
25
26    // Build derive list
27    imports.insert("use serde::{Serialize, Deserialize};".to_string());
28    imports.insert("use sqlx_gen::SqlxGen;".to_string());
29    let mut derive_tokens = vec![
30        quote! { Debug },
31        quote! { Clone },
32        quote! { PartialEq },
33        quote! { Eq },
34        quote! { Serialize },
35        quote! { Deserialize },
36        quote! { sqlx::FromRow },
37        quote! { SqlxGen },
38    ];
39    for d in extra_derives {
40        let ident = format_ident!("{}", d);
41        derive_tokens.push(quote! { #ident });
42    }
43
44    // Build fields
45    let fields: Vec<TokenStream> = table
46        .columns
47        .iter()
48        .map(|col| {
49            let rust_type = resolve_column_type(col, db_kind, table, schema_info, type_overrides);
50            if let Some(imp) = &rust_type.needs_import {
51                imports.insert(imp.clone());
52            }
53
54            let field_name_snake = col.name.to_snake_case();
55            // If the field name is a Rust keyword, prefix with table name
56            // e.g. column "type" on table "connector" → "connector_type"
57            let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
58                let prefixed = format!(
59                    "{}_{}",
60                    table.name.to_snake_case(),
61                    field_name_snake
62                );
63                (prefixed, true)
64            } else {
65                let changed = field_name_snake != col.name;
66                (field_name_snake, changed)
67            };
68
69            let field_ident = format_ident!("{}", effective_name);
70            let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
71                let fallback = format_ident!("String");
72                quote! { #fallback }
73            });
74
75            let rename = if needs_rename {
76                let original = &col.name;
77                quote! { #[sqlx(rename = #original)] }
78            } else {
79                quote! {}
80            };
81
82            // Build #[sqlx_gen(...)] attribute with optional primary_key, sql_type, is_array
83            let (sql_type, is_sql_array) = detect_custom_sql_type(&col.udt_name, schema_info);
84            let has_pk = col.is_primary_key;
85            let has_sql_type = sql_type.is_some();
86
87            let sqlx_gen_attr = if has_pk || has_sql_type {
88                let pk_part = if has_pk { quote! { primary_key, } } else { quote! {} };
89                let sql_type_part = match &sql_type {
90                    Some(t) => quote! { sql_type = #t, },
91                    None => quote! {},
92                };
93                let array_part = if is_sql_array { quote! { is_array, } } else { quote! {} };
94                quote! { #[sqlx_gen(#pk_part #sql_type_part #array_part)] }
95            } else {
96                quote! {}
97            };
98
99            quote! {
100                #sqlx_gen_attr
101                #rename
102                pub #field_ident: #type_tokens,
103            }
104        })
105        .collect();
106
107    let table_name_str = &table.name;
108    let schema_name_str = &table.schema_name;
109    let kind_str = if is_view { "view" } else { "table" };
110
111    let tokens = quote! {
112        #[derive(#(#derive_tokens),*)]
113        #[sqlx_gen(kind = #kind_str, schema = #schema_name_str, table = #table_name_str)]
114        pub struct #struct_name {
115            #(#fields)*
116        }
117    };
118
119    (tokens, imports)
120}
121
122/// Detect if a column uses a custom SQL type (enum or composite) and return the qualified
123/// SQL type name for casting, plus whether it's an array.
124/// Returns `(Some("type_name"), true)` for arrays of custom types,
125/// `(Some("type_name"), false)` for scalar custom types, or `(None, false)` for built-in types.
126fn detect_custom_sql_type(udt_name: &str, schema_info: &SchemaInfo) -> (Option<String>, bool) {
127    let (base_name, is_array) = match udt_name.strip_prefix('_') {
128        Some(inner) => (inner, true),
129        None => (udt_name, false),
130    };
131
132    // Check enums
133    if schema_info.enums.iter().any(|e| e.name == base_name) {
134        return (Some(base_name.to_string()), is_array);
135    }
136
137    // Check composite types
138    if schema_info.composite_types.iter().any(|c| c.name == base_name) {
139        return (Some(base_name.to_string()), is_array);
140    }
141
142    // Check if this is a non-builtin type that would hit the typemap fallback
143    // (e.g. range types like "timerange", "tsrange", etc.)
144    // Domains resolve to their base type, so they don't need marking.
145    let is_domain = schema_info.domains.iter().any(|d| d.name == base_name);
146    if !is_domain && !typemap::postgres::is_builtin(base_name) {
147        return (Some(base_name.to_string()), is_array);
148    }
149
150    (None, false)
151}
152
153fn resolve_column_type(
154    col: &crate::introspect::ColumnInfo,
155    db_kind: DatabaseKind,
156    table: &TableInfo,
157    schema_info: &SchemaInfo,
158    type_overrides: &HashMap<String, String>,
159) -> typemap::RustType {
160    // For MySQL ENUM columns, resolve to the generated enum type
161    if db_kind == DatabaseKind::Mysql && col.udt_name.starts_with("enum(") {
162        let enum_type_name = typemap::mysql::resolve_enum_type(&table.name, &col.name);
163        let rt = typemap::RustType::with_import(
164            &enum_type_name,
165            &format!("use super::types::{};", enum_type_name),
166        );
167        return if col.is_nullable {
168            rt.wrap_option()
169        } else {
170            rt
171        };
172    }
173
174    typemap::map_column(col, db_kind, schema_info, type_overrides)
175}
176
// Unit tests for struct generation. Each test builds table metadata in memory,
// renders it with `generate_struct`, formats the tokens with `parse_and_format`,
// and asserts on substrings of the resulting source and/or the collected imports.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    /// Build a `TableInfo` in the `public` schema with the given columns.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// Build a non-primary-key column in the `public` schema with no default.
    /// `udt_name` is used for both `data_type` and `udt_name`.
    fn make_col(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
            column_default: None,
        }
    }

    /// Render `table` as a Postgres table with an empty schema, no extra derives
    /// and no type overrides; returns the formatted source only.
    // NOTE(review): `gen` is a reserved keyword in the Rust 2024 edition; this
    // helper needs renaming (or `r#gen`) before migrating to that edition.
    fn gen(table: &TableInfo) -> String {
        let schema = SchemaInfo::default();
        let (tokens, _) = generate_struct(table, DatabaseKind::Postgres, &schema, &[], &HashMap::new(), false);
        parse_and_format(&tokens)
    }

    /// Render `table` as a table (not a view) with an explicit schema, database
    /// kind, extra derives and type overrides; returns the formatted source and
    /// the collected import lines.
    fn gen_with(
        table: &TableInfo,
        schema: &SchemaInfo,
        db: DatabaseKind,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_struct(table, db, schema, derives, overrides, false);
        (parse_and_format(&tokens), imports)
    }

    // --- basic structure ---

    #[test]
    fn test_simple_table() {
        let table = make_table("users", vec![
            make_col("id", "int4", false),
            make_col("name", "text", false),
        ]);
        let code = gen(&table);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub name: String"));
    }

    #[test]
    fn test_struct_name_pascal_case() {
        let table = make_table("user_roles", vec![make_col("id", "int4", false)]);
        let code = gen(&table);
        assert!(code.contains("pub struct UserRoles"));
    }

    #[test]
    fn test_struct_name_simple() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = gen(&table);
        assert!(code.contains("pub struct Users"));
    }

    // --- nullable ---

    #[test]
    fn test_nullable_column() {
        let table = make_table("users", vec![make_col("email", "text", true)]);
        let code = gen(&table);
        assert!(code.contains("pub email: Option<String>"));
    }

    #[test]
    fn test_non_nullable_column() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub name: String"));
        // No `Option` may appear anywhere in the generated source.
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_mix_nullable() {
        let table = make_table("users", vec![
            make_col("id", "int4", false),
            make_col("bio", "text", true),
        ]);
        let code = gen(&table);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub bio: Option<String>"));
    }

    // --- keyword renaming ---
    // Keyword column names are prefixed with the table name and mapped back
    // to the original column with `#[sqlx(rename = "...")]`.

    #[test]
    fn test_keyword_type_renamed() {
        let table = make_table("connector", vec![make_col("type", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub connector_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    #[test]
    fn test_keyword_fn_renamed() {
        let table = make_table("item", vec![make_col("fn", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub item_fn: String"));
        assert!(code.contains("sqlx(rename = \"fn\")"));
    }

    #[test]
    fn test_keyword_match_renamed() {
        let table = make_table("game", vec![make_col("match", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub game_match: String"));
    }

    #[test]
    fn test_non_keyword_no_rename() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = gen(&table);
        assert!(!code.contains("sqlx(rename"));
    }

    // --- snake_case renaming ---
    // Columns whose snake_case form differs from the original name also get
    // a rename attribute pointing at the original column name.

    #[test]
    fn test_camel_case_column_renamed() {
        let table = make_table("users", vec![make_col("CreatedAt", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub created_at: String"));
        assert!(code.contains("sqlx(rename = \"CreatedAt\")"));
    }

    #[test]
    fn test_mixed_case_column_renamed() {
        let table = make_table("users", vec![make_col("firstName", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub first_name: String"));
        assert!(code.contains("sqlx(rename = \"firstName\")"));
    }

    #[test]
    fn test_already_snake_case_no_rename() {
        let table = make_table("users", vec![make_col("created_at", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub created_at: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    // --- derives ---

    #[test]
    fn test_default_derives() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = gen(&table);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        // Accept either spacing of the path, depending on how it is formatted.
        assert!(code.contains("sqlx::FromRow") || code.contains("sqlx :: FromRow"));
    }

    // NOTE(review): `Serialize` is already one of the default derives, so this
    // assertion passes regardless of `derives`; it does not prove the extra
    // derive itself was applied.
    #[test]
    fn test_extra_derive_serialize() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let derives = vec!["Serialize".to_string()];
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &derives, &HashMap::new());
        assert!(code.contains("Serialize"));
    }

    // NOTE(review): the serde import is inserted unconditionally by
    // `generate_struct`, so this also passes with no extra derives.
    #[test]
    fn test_extra_derives_both_serde() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &derives, &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    // --- imports ---

    #[test]
    fn test_uuid_import() {
        let table = make_table("users", vec![make_col("id", "uuid", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_timestamptz_import() {
        let table = make_table("users", vec![make_col("created_at", "timestamptz", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    // An int4-only table needs no type imports: only the two unconditional
    // lines (serde derives and the SqlxGen derive) are expected.
    #[test]
    fn test_int4_only_serde_import() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert_eq!(imports.len(), 2);
        assert!(imports.iter().any(|i| i.contains("serde")));
        assert!(imports.iter().any(|i| i.contains("sqlx_gen::SqlxGen")));
    }

    #[test]
    fn test_multiple_imports_collected() {
        let table = make_table("users", vec![
            make_col("id", "uuid", false),
            make_col("created_at", "timestamptz", false),
        ]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid")));
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    // --- MySQL enum ---
    // MySQL `enum(...)` columns resolve to a generated per-column enum type
    // (expected here as `UsersStatus` for table `users`, column `status`)
    // imported from `super::types`.

    #[test]
    fn test_mysql_enum_column() {
        let table = make_table("users", vec![ColumnInfo {
            name: "status".to_string(),
            data_type: "enum".to_string(),
            udt_name: "enum('active','inactive')".to_string(),
            is_nullable: false,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "test_db".to_string(),
            column_default: None,
        }]);
        let schema = SchemaInfo::default();
        let (code, imports) = gen_with(&table, &schema, DatabaseKind::Mysql, &[], &HashMap::new());
        assert!(code.contains("UsersStatus"));
        assert!(imports.iter().any(|i| i.contains("super::types::")));
    }

    #[test]
    fn test_mysql_enum_nullable() {
        let table = make_table("users", vec![ColumnInfo {
            name: "status".to_string(),
            data_type: "enum".to_string(),
            udt_name: "enum('a','b')".to_string(),
            is_nullable: true,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "test_db".to_string(),
            column_default: None,
        }]);
        let schema = SchemaInfo::default();
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Mysql, &[], &HashMap::new());
        assert!(code.contains("Option<UsersStatus>"));
    }

    // --- type overrides ---
    // Overrides are keyed by the column's SQL type name (here "jsonb").

    #[test]
    fn test_type_override() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let schema = SchemaInfo::default();
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &overrides);
        assert!(code.contains("pub data: MyJson"));
    }

    // Without an override, jsonb is expected to map to a type whose name
    // contains `Value` (presumably `serde_json::Value` — confirm in typemap).
    #[test]
    fn test_type_override_absent() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let schema = SchemaInfo::default();
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(code.contains("Value"));
    }

    #[test]
    fn test_type_override_nullable() {
        let table = make_table("users", vec![make_col("data", "jsonb", true)]);
        let schema = SchemaInfo::default();
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &overrides);
        assert!(code.contains("Option<MyJson>"));
    }
}