Skip to main content

sqlx_gen/codegen/
struct_gen.rs

1use std::collections::{BTreeSet, HashMap};
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::{DatabaseKind, TimeCrate};
8use crate::codegen::{imports_for_derives, is_rust_keyword};
9use crate::introspect::{SchemaInfo, TableInfo};
10use crate::typemap;
11
12pub fn generate_struct(
13    table: &TableInfo,
14    db_kind: DatabaseKind,
15    schema_info: &SchemaInfo,
16    extra_derives: &[String],
17    type_overrides: &HashMap<String, String>,
18    is_view: bool,
19    time_crate: TimeCrate,
20) -> (TokenStream, BTreeSet<String>) {
21    let mut imports = BTreeSet::new();
22    for imp in imports_for_derives(extra_derives) {
23        imports.insert(imp);
24    }
25    let struct_name = format_ident!("{}", table.name.to_upper_camel_case());
26
27    // Build derive list
28    imports.insert("use serde::{Serialize, Deserialize};".to_string());
29    imports.insert("use sqlx_gen::SqlxGen;".to_string());
30    let mut derive_tokens = vec![
31        quote! { Debug },
32        quote! { Clone },
33        quote! { PartialEq },
34        quote! { Eq },
35        quote! { Serialize },
36        quote! { Deserialize },
37        quote! { sqlx::FromRow },
38        quote! { SqlxGen },
39    ];
40    for d in extra_derives {
41        let ident = format_ident!("{}", d);
42        derive_tokens.push(quote! { #ident });
43    }
44
45    // Build fields
46    let fields: Vec<TokenStream> = table
47        .columns
48        .iter()
49        .map(|col| {
50            let rust_type = resolve_column_type(col, db_kind, table, schema_info, type_overrides, time_crate);
51            if let Some(imp) = &rust_type.needs_import {
52                imports.insert(imp.clone());
53            }
54
55            let field_name_snake = col.name.to_snake_case();
56            // If the field name is a Rust keyword, prefix with table name
57            // e.g. column "type" on table "connector" → "connector_type"
58            let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
59                let prefixed = format!(
60                    "{}_{}",
61                    table.name.to_snake_case(),
62                    field_name_snake
63                );
64                (prefixed, true)
65            } else {
66                let changed = field_name_snake != col.name;
67                (field_name_snake, changed)
68            };
69
70            let field_ident = format_ident!("{}", effective_name);
71            let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
72                let fallback = format_ident!("String");
73                quote! { #fallback }
74            });
75
76            let rename = if needs_rename {
77                let original = &col.name;
78                quote! { #[sqlx(rename = #original)] }
79            } else {
80                quote! {}
81            };
82
83            // Build #[sqlx_gen(...)] attribute with optional primary_key, sql_type, is_array
84            let (sql_type, is_sql_array) = detect_custom_sql_type(&col.udt_name, schema_info);
85            let has_pk = col.is_primary_key;
86            let has_sql_type = sql_type.is_some();
87
88            let sqlx_gen_attr = if has_pk || has_sql_type {
89                let pk_part = if has_pk { quote! { primary_key, } } else { quote! {} };
90                let sql_type_part = match &sql_type {
91                    Some(t) => quote! { sql_type = #t, },
92                    None => quote! {},
93                };
94                let array_part = if is_sql_array { quote! { is_array, } } else { quote! {} };
95                quote! { #[sqlx_gen(#pk_part #sql_type_part #array_part)] }
96            } else {
97                quote! {}
98            };
99
100            quote! {
101                #sqlx_gen_attr
102                #rename
103                pub #field_ident: #type_tokens,
104            }
105        })
106        .collect();
107
108    let table_name_str = &table.name;
109    let schema_name_str = &table.schema_name;
110    let kind_str = if is_view { "view" } else { "table" };
111
112    let tokens = quote! {
113        #[derive(#(#derive_tokens),*)]
114        #[sqlx_gen(kind = #kind_str, schema = #schema_name_str, table = #table_name_str)]
115        pub struct #struct_name {
116            #(#fields)*
117        }
118    };
119
120    (tokens, imports)
121}
122
123/// Detect if a column uses a custom SQL type (enum or composite) and return the qualified
124/// SQL type name for casting, plus whether it's an array.
125/// Returns `(Some("type_name"), true)` for arrays of custom types,
126/// `(Some("type_name"), false)` for scalar custom types, or `(None, false)` for built-in types.
127fn detect_custom_sql_type(udt_name: &str, schema_info: &SchemaInfo) -> (Option<String>, bool) {
128    let (base_name, is_array) = match udt_name.strip_prefix('_') {
129        Some(inner) => (inner, true),
130        None => (udt_name, false),
131    };
132
133    // Check enums
134    if schema_info.enums.iter().any(|e| e.name == base_name) {
135        return (Some(base_name.to_string()), is_array);
136    }
137
138    // Check composite types
139    if schema_info.composite_types.iter().any(|c| c.name == base_name) {
140        return (Some(base_name.to_string()), is_array);
141    }
142
143    // Check if this is a non-builtin type that would hit the typemap fallback
144    // (e.g. range types like "timerange", "tsrange", etc.)
145    // Domains resolve to their base type, so they don't need marking.
146    let is_domain = schema_info.domains.iter().any(|d| d.name == base_name);
147    if !is_domain && !typemap::postgres::is_builtin(base_name) {
148        return (Some(base_name.to_string()), is_array);
149    }
150
151    (None, false)
152}
153
154fn resolve_column_type(
155    col: &crate::introspect::ColumnInfo,
156    db_kind: DatabaseKind,
157    table: &TableInfo,
158    schema_info: &SchemaInfo,
159    type_overrides: &HashMap<String, String>,
160    time_crate: TimeCrate,
161) -> typemap::RustType {
162    // For MySQL ENUM columns, resolve to the generated enum type
163    if db_kind == DatabaseKind::Mysql && col.udt_name.starts_with("enum(") {
164        let enum_type_name = typemap::mysql::resolve_enum_type(&table.name, &col.name);
165        let rt = typemap::RustType::with_import(
166            &enum_type_name,
167            &format!("use super::types::{};", enum_type_name),
168        );
169        return if col.is_nullable {
170            rt.wrap_option()
171        } else {
172            rt
173        };
174    }
175
176    typemap::map_column(col, db_kind, schema_info, type_overrides, time_crate)
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    /// Build a table in the `public` schema with the given columns.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// Build a non-primary-key column in `public` whose `data_type` equals its `udt_name`.
    fn make_col(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
            column_default: None,
        }
    }

    /// Build a non-nullable primary-key column.
    fn make_pk_col(name: &str, udt_name: &str) -> ColumnInfo {
        ColumnInfo {
            is_primary_key: true,
            ..make_col(name, udt_name, false)
        }
    }

    /// Build a MySQL inline `enum(...)` column in the `test_db` schema.
    fn make_mysql_enum_col(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            data_type: "enum".to_string(),
            schema_name: "test_db".to_string(),
            ..make_col(name, udt_name, nullable)
        }
    }

    /// Generate and format a Postgres table struct with no extra derives or overrides.
    /// (Not named `gen`: that identifier is a reserved keyword in the Rust 2024 edition.)
    fn generate(table: &TableInfo) -> String {
        let schema = SchemaInfo::default();
        let (tokens, _) = generate_struct(
            table,
            DatabaseKind::Postgres,
            &schema,
            &[],
            &HashMap::new(),
            false,
            TimeCrate::Chrono,
        );
        parse_and_format(&tokens)
    }

    /// Generate and format a Postgres view struct with no extra derives or overrides.
    fn generate_view(table: &TableInfo) -> String {
        let schema = SchemaInfo::default();
        let (tokens, _) = generate_struct(
            table,
            DatabaseKind::Postgres,
            &schema,
            &[],
            &HashMap::new(),
            true,
            TimeCrate::Chrono,
        );
        parse_and_format(&tokens)
    }

    /// Generate a table struct with explicit options; returns formatted code and imports.
    fn gen_with(
        table: &TableInfo,
        schema: &SchemaInfo,
        db: DatabaseKind,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) =
            generate_struct(table, db, schema, derives, overrides, false, TimeCrate::Chrono);
        (parse_and_format(&tokens), imports)
    }

    // --- basic structure ---

    #[test]
    fn test_simple_table() {
        let table = make_table("users", vec![
            make_col("id", "int4", false),
            make_col("name", "text", false),
        ]);
        let code = generate(&table);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub name: String"));
    }

    #[test]
    fn test_struct_name_pascal_case() {
        let table = make_table("user_roles", vec![make_col("id", "int4", false)]);
        let code = generate(&table);
        assert!(code.contains("pub struct UserRoles"));
    }

    #[test]
    fn test_struct_name_simple() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = generate(&table);
        assert!(code.contains("pub struct Users"));
    }

    // --- struct-level #[sqlx_gen(...)] attribute ---

    #[test]
    fn test_table_kind_attribute() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = generate(&table);
        assert!(code.contains("\"table\""));
        assert!(code.contains("\"public\""));
        assert!(code.contains("\"users\""));
    }

    #[test]
    fn test_view_kind_attribute() {
        let table = make_table("active_users", vec![make_col("id", "int4", false)]);
        let code = generate_view(&table);
        assert!(code.contains("\"view\""));
        assert!(!code.contains("\"table\""));
    }

    // --- field-level #[sqlx_gen(...)] attribute ---

    #[test]
    fn test_primary_key_attribute() {
        let table = make_table("users", vec![
            make_pk_col("id", "int4"),
            make_col("name", "text", false),
        ]);
        let code = generate(&table);
        assert_eq!(code.matches("primary_key").count(), 1);
    }

    #[test]
    fn test_builtin_column_has_no_field_attribute() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = generate(&table);
        assert!(!code.contains("primary_key"));
        assert!(!code.contains("sql_type"));
        assert!(!code.contains("is_array"));
    }

    // --- nullable ---

    #[test]
    fn test_nullable_column() {
        let table = make_table("users", vec![make_col("email", "text", true)]);
        let code = generate(&table);
        assert!(code.contains("pub email: Option<String>"));
    }

    #[test]
    fn test_non_nullable_column() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = generate(&table);
        assert!(code.contains("pub name: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_mix_nullable() {
        let table = make_table("users", vec![
            make_col("id", "int4", false),
            make_col("bio", "text", true),
        ]);
        let code = generate(&table);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub bio: Option<String>"));
    }

    // --- keyword renaming ---

    #[test]
    fn test_keyword_type_renamed() {
        let table = make_table("connector", vec![make_col("type", "text", false)]);
        let code = generate(&table);
        assert!(code.contains("pub connector_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    #[test]
    fn test_keyword_fn_renamed() {
        let table = make_table("item", vec![make_col("fn", "text", false)]);
        let code = generate(&table);
        assert!(code.contains("pub item_fn: String"));
        assert!(code.contains("sqlx(rename = \"fn\")"));
    }

    #[test]
    fn test_keyword_match_renamed() {
        let table = make_table("game", vec![make_col("match", "text", false)]);
        let code = generate(&table);
        assert!(code.contains("pub game_match: String"));
        assert!(code.contains("sqlx(rename = \"match\")"));
    }

    #[test]
    fn test_non_keyword_no_rename() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = generate(&table);
        assert!(!code.contains("sqlx(rename"));
    }

    // --- snake_case renaming ---

    #[test]
    fn test_camel_case_column_renamed() {
        let table = make_table("users", vec![make_col("CreatedAt", "text", false)]);
        let code = generate(&table);
        assert!(code.contains("pub created_at: String"));
        assert!(code.contains("sqlx(rename = \"CreatedAt\")"));
    }

    #[test]
    fn test_mixed_case_column_renamed() {
        let table = make_table("users", vec![make_col("firstName", "text", false)]);
        let code = generate(&table);
        assert!(code.contains("pub first_name: String"));
        assert!(code.contains("sqlx(rename = \"firstName\")"));
    }

    #[test]
    fn test_already_snake_case_no_rename() {
        let table = make_table("users", vec![make_col("created_at", "text", false)]);
        let code = generate(&table);
        assert!(code.contains("pub created_at: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    // --- derives ---

    #[test]
    fn test_default_derives() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = generate(&table);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("Serialize"));
        assert!(code.contains("Deserialize"));
        assert!(code.contains("SqlxGen"));
        assert!(code.contains("sqlx::FromRow") || code.contains("sqlx :: FromRow"));
    }

    #[test]
    fn test_extra_derive_appended() {
        // Use a derive that is not already a default, so the assertion is meaningful.
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let derives = vec!["Hash".to_string()];
        let (code, _) =
            gen_with(&table, &schema, DatabaseKind::Postgres, &derives, &HashMap::new());
        assert!(code.contains("Hash"));
    }

    #[test]
    fn test_extra_derives_both_serde() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) =
            gen_with(&table, &schema, DatabaseKind::Postgres, &derives, &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    // --- imports ---

    #[test]
    fn test_uuid_import() {
        let table = make_table("users", vec![make_col("id", "uuid", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_timestamptz_import() {
        let table = make_table("users", vec![make_col("created_at", "timestamptz", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    #[test]
    fn test_int4_only_serde_import() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert_eq!(imports.len(), 2);
        assert!(imports.iter().any(|i| i.contains("serde")));
        assert!(imports.iter().any(|i| i.contains("sqlx_gen::SqlxGen")));
    }

    #[test]
    fn test_multiple_imports_collected() {
        let table = make_table("users", vec![
            make_col("id", "uuid", false),
            make_col("created_at", "timestamptz", false),
        ]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid")));
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    // --- MySQL enum ---

    #[test]
    fn test_mysql_enum_column() {
        let table = make_table("users", vec![make_mysql_enum_col(
            "status",
            "enum('active','inactive')",
            false,
        )]);
        let schema = SchemaInfo::default();
        let (code, imports) = gen_with(&table, &schema, DatabaseKind::Mysql, &[], &HashMap::new());
        assert!(code.contains("UsersStatus"));
        assert!(imports.iter().any(|i| i.contains("super::types::")));
    }

    #[test]
    fn test_mysql_enum_nullable() {
        let table = make_table("users", vec![make_mysql_enum_col("status", "enum('a','b')", true)]);
        let schema = SchemaInfo::default();
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Mysql, &[], &HashMap::new());
        assert!(code.contains("Option<UsersStatus>"));
    }

    // --- type overrides ---

    #[test]
    fn test_type_override() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let schema = SchemaInfo::default();
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &overrides);
        assert!(code.contains("pub data: MyJson"));
    }

    #[test]
    fn test_type_override_absent() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let schema = SchemaInfo::default();
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(code.contains("Value"));
    }

    #[test]
    fn test_type_override_nullable() {
        let table = make_table("users", vec![make_col("data", "jsonb", true)]);
        let schema = SchemaInfo::default();
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &overrides);
        assert!(code.contains("Option<MyJson>"));
    }
}