Skip to main content

sqlx_gen/codegen/
struct_gen.rs

1use std::collections::{BTreeSet, HashMap};
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::{imports_for_derives, is_rust_keyword};
9use crate::introspect::{SchemaInfo, TableInfo};
10use crate::typemap;
11
12pub fn generate_struct(
13    table: &TableInfo,
14    db_kind: DatabaseKind,
15    schema_info: &SchemaInfo,
16    extra_derives: &[String],
17    type_overrides: &HashMap<String, String>,
18) -> (TokenStream, BTreeSet<String>) {
19    let mut imports = BTreeSet::new();
20    for imp in imports_for_derives(extra_derives) {
21        imports.insert(imp);
22    }
23    let struct_name = format_ident!("{}", table.name.to_upper_camel_case());
24
25    // Build derive list
26    let mut derive_tokens = vec![
27        quote! { Debug },
28        quote! { Clone },
29        quote! { sqlx::FromRow },
30    ];
31    for d in extra_derives {
32        let ident = format_ident!("{}", d);
33        derive_tokens.push(quote! { #ident });
34    }
35
36    // Build fields
37    let fields: Vec<TokenStream> = table
38        .columns
39        .iter()
40        .map(|col| {
41            let rust_type = resolve_column_type(col, db_kind, table, schema_info, type_overrides);
42            if let Some(imp) = &rust_type.needs_import {
43                imports.insert(imp.clone());
44            }
45
46            let field_name_snake = col.name.to_snake_case();
47            // If the field name is a Rust keyword, prefix with table name
48            // e.g. column "type" on table "connector" → "connector_type"
49            let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
50                let prefixed = format!(
51                    "{}_{}",
52                    table.name.to_snake_case(),
53                    field_name_snake
54                );
55                (prefixed, true)
56            } else {
57                let changed = field_name_snake != col.name;
58                (field_name_snake, changed)
59            };
60
61            let field_ident = format_ident!("{}", effective_name);
62            let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
63                let fallback = format_ident!("String");
64                quote! { #fallback }
65            });
66
67            let rename = if needs_rename {
68                let original = &col.name;
69                quote! { #[sqlx(rename = #original)] }
70            } else {
71                quote! {}
72            };
73
74            quote! {
75                #rename
76                pub #field_ident: #type_tokens,
77            }
78        })
79        .collect();
80
81    let tokens = quote! {
82        #[derive(#(#derive_tokens),*)]
83        pub struct #struct_name {
84            #(#fields)*
85        }
86    };
87
88    (tokens, imports)
89}
90
91fn resolve_column_type(
92    col: &crate::introspect::ColumnInfo,
93    db_kind: DatabaseKind,
94    table: &TableInfo,
95    schema_info: &SchemaInfo,
96    type_overrides: &HashMap<String, String>,
97) -> typemap::RustType {
98    // For MySQL ENUM columns, resolve to the generated enum type
99    if db_kind == DatabaseKind::Mysql && col.udt_name.starts_with("enum(") {
100        let enum_type_name = typemap::mysql::resolve_enum_type(&table.name, &col.name);
101        let rt = typemap::RustType::with_import(
102            &enum_type_name,
103            &format!("use super::types::{};", enum_type_name),
104        );
105        return if col.is_nullable {
106            rt.wrap_option()
107        } else {
108            rt
109        };
110    }
111
112    typemap::map_column(col, db_kind, schema_info, type_overrides)
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    /// Builds a table in the `public` schema with the given columns.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// Builds a `public`-schema column whose `data_type` and `udt_name` are
    /// both set to `udt_name`.
    fn make_col(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            ordinal_position: 0,
            schema_name: "public".to_string(),
        }
    }

    /// Builds a MySQL `ENUM` column named `status` in schema `test_db`, with
    /// `udt_name` carrying the full `enum(...)` definition.
    fn make_mysql_enum_col(udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: "status".to_string(),
            data_type: "enum".to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            ordinal_position: 0,
            schema_name: "test_db".to_string(),
        }
    }

    /// Renders `table` for Postgres with an empty schema, no extra derives and
    /// no type overrides, returning only the formatted code.
    ///
    /// Deliberately not called `gen`: that identifier is a reserved keyword
    /// starting with the Rust 2024 edition.
    fn render(table: &TableInfo) -> String {
        let schema = SchemaInfo::default();
        let (tokens, _) =
            generate_struct(table, DatabaseKind::Postgres, &schema, &[], &HashMap::new());
        parse_and_format(&tokens)
    }

    /// Renders `table` with full control over every `generate_struct` input,
    /// returning the formatted code and the collected imports.
    fn render_with(
        table: &TableInfo,
        schema: &SchemaInfo,
        db: DatabaseKind,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_struct(table, db, schema, derives, overrides);
        (parse_and_format(&tokens), imports)
    }

    // --- basic structure ---

    #[test]
    fn test_simple_table() {
        let table = make_table(
            "users",
            vec![make_col("id", "int4", false), make_col("name", "text", false)],
        );
        let code = render(&table);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub name: String"));
    }

    #[test]
    fn test_struct_name_pascal_case() {
        let table = make_table("user_roles", vec![make_col("id", "int4", false)]);
        let code = render(&table);
        assert!(code.contains("pub struct UserRoles"));
    }

    #[test]
    fn test_struct_name_simple() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = render(&table);
        assert!(code.contains("pub struct Users"));
    }

    // --- nullable ---

    #[test]
    fn test_nullable_column() {
        let table = make_table("users", vec![make_col("email", "text", true)]);
        let code = render(&table);
        assert!(code.contains("pub email: Option<String>"));
    }

    #[test]
    fn test_non_nullable_column() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = render(&table);
        assert!(code.contains("pub name: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_mix_nullable() {
        let table = make_table(
            "users",
            vec![make_col("id", "int4", false), make_col("bio", "text", true)],
        );
        let code = render(&table);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub bio: Option<String>"));
    }

    // --- keyword renaming ---

    #[test]
    fn test_keyword_type_renamed() {
        let table = make_table("connector", vec![make_col("type", "text", false)]);
        let code = render(&table);
        assert!(code.contains("pub connector_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    #[test]
    fn test_keyword_fn_renamed() {
        let table = make_table("item", vec![make_col("fn", "text", false)]);
        let code = render(&table);
        assert!(code.contains("pub item_fn: String"));
        assert!(code.contains("sqlx(rename = \"fn\")"));
    }

    #[test]
    fn test_keyword_match_renamed() {
        let table = make_table("game", vec![make_col("match", "text", false)]);
        let code = render(&table);
        assert!(code.contains("pub game_match: String"));
    }

    #[test]
    fn test_non_keyword_no_rename() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = render(&table);
        assert!(!code.contains("sqlx(rename"));
    }

    // --- snake_case renaming ---

    #[test]
    fn test_camel_case_column_renamed() {
        let table = make_table("users", vec![make_col("CreatedAt", "text", false)]);
        let code = render(&table);
        assert!(code.contains("pub created_at: String"));
        assert!(code.contains("sqlx(rename = \"CreatedAt\")"));
    }

    #[test]
    fn test_mixed_case_column_renamed() {
        let table = make_table("users", vec![make_col("firstName", "text", false)]);
        let code = render(&table);
        assert!(code.contains("pub first_name: String"));
        assert!(code.contains("sqlx(rename = \"firstName\")"));
    }

    #[test]
    fn test_already_snake_case_no_rename() {
        let table = make_table("users", vec![make_col("created_at", "text", false)]);
        let code = render(&table);
        assert!(code.contains("pub created_at: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    // --- derives ---

    #[test]
    fn test_default_derives() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = render(&table);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("sqlx::FromRow") || code.contains("sqlx :: FromRow"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let derives = vec!["Serialize".to_string()];
        let (code, _) = render_with(
            &table,
            &schema,
            DatabaseKind::Postgres,
            &derives,
            &HashMap::new(),
        );
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derives_both_serde() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = render_with(
            &table,
            &schema,
            DatabaseKind::Postgres,
            &derives,
            &HashMap::new(),
        );
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    // --- imports ---

    #[test]
    fn test_uuid_import() {
        let table = make_table("users", vec![make_col("id", "uuid", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) =
            render_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_timestamptz_import() {
        let table = make_table("users", vec![make_col("created_at", "timestamptz", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) =
            render_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    #[test]
    fn test_int4_no_import() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) =
            render_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(imports.is_empty());
    }

    #[test]
    fn test_multiple_imports_collected() {
        let table = make_table(
            "users",
            vec![
                make_col("id", "uuid", false),
                make_col("created_at", "timestamptz", false),
            ],
        );
        let schema = SchemaInfo::default();
        let (_, imports) =
            render_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid")));
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    // --- MySQL enum ---

    #[test]
    fn test_mysql_enum_column() {
        let table = make_table(
            "users",
            vec![make_mysql_enum_col("enum('active','inactive')", false)],
        );
        let schema = SchemaInfo::default();
        let (code, imports) =
            render_with(&table, &schema, DatabaseKind::Mysql, &[], &HashMap::new());
        assert!(code.contains("UsersStatus"));
        assert!(imports.iter().any(|i| i.contains("super::types::")));
    }

    #[test]
    fn test_mysql_enum_nullable() {
        let table = make_table("users", vec![make_mysql_enum_col("enum('a','b')", true)]);
        let schema = SchemaInfo::default();
        let (code, _) = render_with(&table, &schema, DatabaseKind::Mysql, &[], &HashMap::new());
        assert!(code.contains("Option<UsersStatus>"));
    }

    // --- type overrides ---

    #[test]
    fn test_type_override() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let schema = SchemaInfo::default();
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = render_with(&table, &schema, DatabaseKind::Postgres, &[], &overrides);
        assert!(code.contains("pub data: MyJson"));
    }

    #[test]
    fn test_type_override_absent() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let schema = SchemaInfo::default();
        let (code, _) =
            render_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(code.contains("Value"));
    }

    #[test]
    fn test_type_override_nullable() {
        let table = make_table("users", vec![make_col("data", "jsonb", true)]);
        let schema = SchemaInfo::default();
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = render_with(&table, &schema, DatabaseKind::Postgres, &[], &overrides);
        assert!(code.contains("Option<MyJson>"));
    }
}