Skip to main content

sqlx_gen/codegen/
struct_gen.rs

1use std::collections::{BTreeSet, HashMap};
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::{imports_for_derives, is_rust_keyword};
9use crate::introspect::{SchemaInfo, TableInfo};
10use crate::typemap;
11
12pub fn generate_struct(
13    table: &TableInfo,
14    db_kind: DatabaseKind,
15    schema_info: &SchemaInfo,
16    extra_derives: &[String],
17    type_overrides: &HashMap<String, String>,
18) -> (TokenStream, BTreeSet<String>) {
19    let mut imports = BTreeSet::new();
20    for imp in imports_for_derives(extra_derives) {
21        imports.insert(imp);
22    }
23    let struct_name = format_ident!("{}", table.name.to_upper_camel_case());
24
25    // Build derive list
26    imports.insert("use serde::{Serialize, Deserialize};".to_string());
27    let mut derive_tokens = vec![
28        quote! { Debug },
29        quote! { Clone },
30        quote! { PartialEq },
31        quote! { Eq },
32        quote! { Serialize },
33        quote! { Deserialize },
34        quote! { sqlx::FromRow },
35    ];
36    for d in extra_derives {
37        let ident = format_ident!("{}", d);
38        derive_tokens.push(quote! { #ident });
39    }
40
41    // Build fields
42    let fields: Vec<TokenStream> = table
43        .columns
44        .iter()
45        .map(|col| {
46            let rust_type = resolve_column_type(col, db_kind, table, schema_info, type_overrides);
47            if let Some(imp) = &rust_type.needs_import {
48                imports.insert(imp.clone());
49            }
50
51            let field_name_snake = col.name.to_snake_case();
52            // If the field name is a Rust keyword, prefix with table name
53            // e.g. column "type" on table "connector" → "connector_type"
54            let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
55                let prefixed = format!(
56                    "{}_{}",
57                    table.name.to_snake_case(),
58                    field_name_snake
59                );
60                (prefixed, true)
61            } else {
62                let changed = field_name_snake != col.name;
63                (field_name_snake, changed)
64            };
65
66            let field_ident = format_ident!("{}", effective_name);
67            let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
68                let fallback = format_ident!("String");
69                quote! { #fallback }
70            });
71
72            let rename = if needs_rename {
73                let original = &col.name;
74                quote! { #[sqlx(rename = #original)] }
75            } else {
76                quote! {}
77            };
78
79            quote! {
80                #rename
81                pub #field_ident: #type_tokens,
82            }
83        })
84        .collect();
85
86    let tokens = quote! {
87        #[derive(#(#derive_tokens),*)]
88        pub struct #struct_name {
89            #(#fields)*
90        }
91    };
92
93    (tokens, imports)
94}
95
96fn resolve_column_type(
97    col: &crate::introspect::ColumnInfo,
98    db_kind: DatabaseKind,
99    table: &TableInfo,
100    schema_info: &SchemaInfo,
101    type_overrides: &HashMap<String, String>,
102) -> typemap::RustType {
103    // For MySQL ENUM columns, resolve to the generated enum type
104    if db_kind == DatabaseKind::Mysql && col.udt_name.starts_with("enum(") {
105        let enum_type_name = typemap::mysql::resolve_enum_type(&table.name, &col.name);
106        let rt = typemap::RustType::with_import(
107            &enum_type_name,
108            &format!("use super::types::{};", enum_type_name),
109        );
110        return if col.is_nullable {
111            rt.wrap_option()
112        } else {
113            rt
114        };
115    }
116
117    typemap::map_column(col, db_kind, schema_info, type_overrides)
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    /// Builds a table in the `public` schema.
    fn table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// Builds a `public`-schema column whose `data_type` and `udt_name` are both `udt`.
    fn column(name: &str, udt: &str, is_nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt.to_string(),
            udt_name: udt.to_string(),
            is_nullable,
            ordinal_position: 0,
            schema_name: "public".to_string(),
        }
    }

    /// Builds a MySQL `ENUM` column in the `test_db` schema; `column_type` is the full
    /// `enum(...)` definition.
    fn mysql_enum(name: &str, column_type: &str, is_nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: "enum".to_string(),
            udt_name: column_type.to_string(),
            is_nullable,
            ordinal_position: 0,
            schema_name: "test_db".to_string(),
        }
    }

    /// Runs the generator against an empty schema, returning formatted code and imports.
    fn render(
        tbl: &TableInfo,
        db: DatabaseKind,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let schema = SchemaInfo::default();
        let (tokens, imports) = generate_struct(tbl, db, &schema, derives, overrides);
        (parse_and_format(&tokens), imports)
    }

    /// Postgres with no extra derives or overrides: formatted code only.
    fn pg_code(tbl: &TableInfo) -> String {
        render(tbl, DatabaseKind::Postgres, &[], &HashMap::new()).0
    }

    /// Postgres with no extra derives or overrides: collected imports only.
    fn pg_imports(tbl: &TableInfo) -> BTreeSet<String> {
        render(tbl, DatabaseKind::Postgres, &[], &HashMap::new()).1
    }

    /// Type overrides mapping `jsonb` to `MyJson`.
    fn jsonb_as_my_json() -> HashMap<String, String> {
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        overrides
    }

    /// Asserts that `haystack` contains every entry of `needles`.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for needle in needles {
            assert!(
                haystack.contains(*needle),
                "expected `{}` in:\n{}",
                needle,
                haystack
            );
        }
    }

    // --- basic structure ---

    #[test]
    fn test_simple_table() {
        let users = table(
            "users",
            vec![column("id", "int4", false), column("name", "text", false)],
        );
        assert_contains_all(&pg_code(&users), &["pub id: i32", "pub name: String"]);
    }

    #[test]
    fn test_struct_name_pascal_case() {
        let code = pg_code(&table("user_roles", vec![column("id", "int4", false)]));
        assert!(code.contains("pub struct UserRoles"));
    }

    #[test]
    fn test_struct_name_simple() {
        let code = pg_code(&table("users", vec![column("id", "int4", false)]));
        assert!(code.contains("pub struct Users"));
    }

    // --- nullable ---

    #[test]
    fn test_nullable_column() {
        let code = pg_code(&table("users", vec![column("email", "text", true)]));
        assert!(code.contains("pub email: Option<String>"));
    }

    #[test]
    fn test_non_nullable_column() {
        let code = pg_code(&table("users", vec![column("name", "text", false)]));
        assert!(code.contains("pub name: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_mix_nullable() {
        let users = table(
            "users",
            vec![column("id", "int4", false), column("bio", "text", true)],
        );
        assert_contains_all(&pg_code(&users), &["pub id: i32", "pub bio: Option<String>"]);
    }

    // --- keyword renaming ---

    #[test]
    fn test_keyword_type_renamed() {
        let code = pg_code(&table("connector", vec![column("type", "text", false)]));
        assert_contains_all(&code, &["pub connector_type: String", "sqlx(rename = \"type\")"]);
    }

    #[test]
    fn test_keyword_fn_renamed() {
        let code = pg_code(&table("item", vec![column("fn", "text", false)]));
        assert_contains_all(&code, &["pub item_fn: String", "sqlx(rename = \"fn\")"]);
    }

    #[test]
    fn test_keyword_match_renamed() {
        let code = pg_code(&table("game", vec![column("match", "text", false)]));
        assert!(code.contains("pub game_match: String"));
    }

    #[test]
    fn test_non_keyword_no_rename() {
        let code = pg_code(&table("users", vec![column("name", "text", false)]));
        assert!(!code.contains("sqlx(rename"));
    }

    // --- snake_case renaming ---

    #[test]
    fn test_camel_case_column_renamed() {
        let code = pg_code(&table("users", vec![column("CreatedAt", "text", false)]));
        assert_contains_all(
            &code,
            &["pub created_at: String", "sqlx(rename = \"CreatedAt\")"],
        );
    }

    #[test]
    fn test_mixed_case_column_renamed() {
        let code = pg_code(&table("users", vec![column("firstName", "text", false)]));
        assert_contains_all(
            &code,
            &["pub first_name: String", "sqlx(rename = \"firstName\")"],
        );
    }

    #[test]
    fn test_already_snake_case_no_rename() {
        let code = pg_code(&table("users", vec![column("created_at", "text", false)]));
        assert!(code.contains("pub created_at: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    // --- derives ---

    #[test]
    fn test_default_derives() {
        let code = pg_code(&table("users", vec![column("id", "int4", false)]));
        assert_contains_all(&code, &["Debug", "Clone"]);
        assert!(["sqlx::FromRow", "sqlx :: FromRow"]
            .iter()
            .any(|form| code.contains(*form)));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let users = table("users", vec![column("id", "int4", false)]);
        let derives = ["Serialize".to_string()];
        let (code, _) = render(&users, DatabaseKind::Postgres, &derives, &HashMap::new());
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derives_both_serde() {
        let users = table("users", vec![column("id", "int4", false)]);
        let derives = ["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = render(&users, DatabaseKind::Postgres, &derives, &HashMap::new());
        assert!(imports.iter().any(|line| line.contains("serde")));
    }

    // --- imports ---

    #[test]
    fn test_uuid_import() {
        let imports = pg_imports(&table("users", vec![column("id", "uuid", false)]));
        assert!(imports.iter().any(|line| line.contains("uuid::Uuid")));
    }

    #[test]
    fn test_timestamptz_import() {
        let imports = pg_imports(&table(
            "users",
            vec![column("created_at", "timestamptz", false)],
        ));
        assert!(imports.iter().any(|line| line.contains("chrono")));
    }

    #[test]
    fn test_int4_only_serde_import() {
        let imports = pg_imports(&table("users", vec![column("id", "int4", false)]));
        assert_eq!(imports.len(), 1);
        assert!(imports.iter().any(|line| line.contains("serde")));
    }

    #[test]
    fn test_multiple_imports_collected() {
        let imports = pg_imports(&table(
            "users",
            vec![
                column("id", "uuid", false),
                column("created_at", "timestamptz", false),
            ],
        ));
        for crate_name in ["uuid", "chrono"] {
            assert!(imports.iter().any(|line| line.contains(crate_name)));
        }
    }

    // --- MySQL enum ---

    #[test]
    fn test_mysql_enum_column() {
        let users = table(
            "users",
            vec![mysql_enum("status", "enum('active','inactive')", false)],
        );
        let (code, imports) = render(&users, DatabaseKind::Mysql, &[], &HashMap::new());
        assert!(code.contains("UsersStatus"));
        assert!(imports.iter().any(|line| line.contains("super::types::")));
    }

    #[test]
    fn test_mysql_enum_nullable() {
        let users = table("users", vec![mysql_enum("status", "enum('a','b')", true)]);
        let (code, _) = render(&users, DatabaseKind::Mysql, &[], &HashMap::new());
        assert!(code.contains("Option<UsersStatus>"));
    }

    // --- type overrides ---

    #[test]
    fn test_type_override() {
        let users = table("users", vec![column("data", "jsonb", false)]);
        let (code, _) = render(&users, DatabaseKind::Postgres, &[], &jsonb_as_my_json());
        assert!(code.contains("pub data: MyJson"));
    }

    #[test]
    fn test_type_override_absent() {
        let code = pg_code(&table("users", vec![column("data", "jsonb", false)]));
        assert!(code.contains("Value"));
    }

    #[test]
    fn test_type_override_nullable() {
        let users = table("users", vec![column("data", "jsonb", true)]);
        let (code, _) = render(&users, DatabaseKind::Postgres, &[], &jsonb_as_my_json());
        assert!(code.contains("Option<MyJson>"));
    }
}