Skip to main content

sqlx_gen/codegen/
struct_gen.rs

1use std::collections::{BTreeSet, HashMap};
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::codegen::{imports_for_derives, is_rust_keyword};
9use crate::introspect::{SchemaInfo, TableInfo};
10use crate::typemap;
11
12pub fn generate_struct(
13    table: &TableInfo,
14    db_kind: DatabaseKind,
15    schema_info: &SchemaInfo,
16    extra_derives: &[String],
17    type_overrides: &HashMap<String, String>,
18    is_view: bool,
19) -> (TokenStream, BTreeSet<String>) {
20    let mut imports = BTreeSet::new();
21    for imp in imports_for_derives(extra_derives) {
22        imports.insert(imp);
23    }
24    let struct_name = format_ident!("{}", table.name.to_upper_camel_case());
25
26    // Build derive list
27    imports.insert("use serde::{Serialize, Deserialize};".to_string());
28    imports.insert("use sqlx_gen::SqlxGen;".to_string());
29    let mut derive_tokens = vec![
30        quote! { Debug },
31        quote! { Clone },
32        quote! { PartialEq },
33        quote! { Eq },
34        quote! { Serialize },
35        quote! { Deserialize },
36        quote! { sqlx::FromRow },
37        quote! { SqlxGen },
38    ];
39    for d in extra_derives {
40        let ident = format_ident!("{}", d);
41        derive_tokens.push(quote! { #ident });
42    }
43
44    // Build fields
45    let fields: Vec<TokenStream> = table
46        .columns
47        .iter()
48        .map(|col| {
49            let rust_type = resolve_column_type(col, db_kind, table, schema_info, type_overrides);
50            if let Some(imp) = &rust_type.needs_import {
51                imports.insert(imp.clone());
52            }
53
54            let field_name_snake = col.name.to_snake_case();
55            // If the field name is a Rust keyword, prefix with table name
56            // e.g. column "type" on table "connector" → "connector_type"
57            let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
58                let prefixed = format!(
59                    "{}_{}",
60                    table.name.to_snake_case(),
61                    field_name_snake
62                );
63                (prefixed, true)
64            } else {
65                let changed = field_name_snake != col.name;
66                (field_name_snake, changed)
67            };
68
69            let field_ident = format_ident!("{}", effective_name);
70            let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
71                let fallback = format_ident!("String");
72                quote! { #fallback }
73            });
74
75            let rename = if needs_rename {
76                let original = &col.name;
77                quote! { #[sqlx(rename = #original)] }
78            } else {
79                quote! {}
80            };
81
82            let pk_attr = if col.is_primary_key {
83                quote! { #[sqlx_gen(primary_key)] }
84            } else {
85                quote! {}
86            };
87
88            quote! {
89                #pk_attr
90                #rename
91                pub #field_ident: #type_tokens,
92            }
93        })
94        .collect();
95
96    let table_name_str = &table.name;
97    let kind_str = if is_view { "view" } else { "table" };
98
99    let tokens = quote! {
100        #[derive(#(#derive_tokens),*)]
101        #[sqlx_gen(kind = #kind_str, table = #table_name_str)]
102        pub struct #struct_name {
103            #(#fields)*
104        }
105    };
106
107    (tokens, imports)
108}
109
110fn resolve_column_type(
111    col: &crate::introspect::ColumnInfo,
112    db_kind: DatabaseKind,
113    table: &TableInfo,
114    schema_info: &SchemaInfo,
115    type_overrides: &HashMap<String, String>,
116) -> typemap::RustType {
117    // For MySQL ENUM columns, resolve to the generated enum type
118    if db_kind == DatabaseKind::Mysql && col.udt_name.starts_with("enum(") {
119        let enum_type_name = typemap::mysql::resolve_enum_type(&table.name, &col.name);
120        let rt = typemap::RustType::with_import(
121            &enum_type_name,
122            &format!("use super::types::{};", enum_type_name),
123        );
124        return if col.is_nullable {
125            rt.wrap_option()
126        } else {
127            rt
128        };
129    }
130
131    typemap::map_column(col, db_kind, schema_info, type_overrides)
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    fn make_col(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
        }
    }

    /// A non-nullable column flagged as part of the primary key.
    fn make_pk_col(name: &str, udt_name: &str) -> ColumnInfo {
        ColumnInfo {
            is_primary_key: true,
            ..make_col(name, udt_name, false)
        }
    }

    /// Generates a Postgres table struct with no extra derives or overrides.
    fn gen(table: &TableInfo) -> String {
        let schema = SchemaInfo::default();
        let (tokens, _) = generate_struct(table, DatabaseKind::Postgres, &schema, &[], &HashMap::new(), false);
        parse_and_format(&tokens)
    }

    /// Generates a table struct with explicit options; returns code and imports.
    fn gen_with(
        table: &TableInfo,
        schema: &SchemaInfo,
        db: DatabaseKind,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_struct(table, db, schema, derives, overrides, false);
        (parse_and_format(&tokens), imports)
    }

    // --- basic structure ---

    #[test]
    fn test_simple_table() {
        let table = make_table("users", vec![
            make_col("id", "int4", false),
            make_col("name", "text", false),
        ]);
        let code = gen(&table);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub name: String"));
    }

    #[test]
    fn test_struct_name_pascal_case() {
        let table = make_table("user_roles", vec![make_col("id", "int4", false)]);
        let code = gen(&table);
        assert!(code.contains("pub struct UserRoles"));
    }

    #[test]
    fn test_struct_name_simple() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = gen(&table);
        assert!(code.contains("pub struct Users"));
    }

    // --- struct / field attributes ---

    #[test]
    fn test_primary_key_attribute() {
        let table = make_table("users", vec![
            make_pk_col("id", "int4"),
            make_col("name", "text", false),
        ]);
        let code = gen(&table);
        // Only the primary-key column carries the attribute.
        assert_eq!(code.matches("sqlx_gen(primary_key)").count(), 1);
    }

    #[test]
    fn test_table_kind_attribute() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = gen(&table);
        assert!(code.contains("kind = \"table\""));
        assert!(code.contains("table = \"users\""));
    }

    #[test]
    fn test_view_kind_attribute() {
        let table = make_table("active_users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let (tokens, _) = generate_struct(&table, DatabaseKind::Postgres, &schema, &[], &HashMap::new(), true);
        let code = parse_and_format(&tokens);
        assert!(code.contains("kind = \"view\""));
        assert!(code.contains("table = \"active_users\""));
    }

    // --- nullable ---

    #[test]
    fn test_nullable_column() {
        let table = make_table("users", vec![make_col("email", "text", true)]);
        let code = gen(&table);
        assert!(code.contains("pub email: Option<String>"));
    }

    #[test]
    fn test_non_nullable_column() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub name: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_mix_nullable() {
        let table = make_table("users", vec![
            make_col("id", "int4", false),
            make_col("bio", "text", true),
        ]);
        let code = gen(&table);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub bio: Option<String>"));
    }

    // --- keyword renaming ---

    #[test]
    fn test_keyword_type_renamed() {
        let table = make_table("connector", vec![make_col("type", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub connector_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    #[test]
    fn test_keyword_fn_renamed() {
        let table = make_table("item", vec![make_col("fn", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub item_fn: String"));
        assert!(code.contains("sqlx(rename = \"fn\")"));
    }

    #[test]
    fn test_keyword_match_renamed() {
        let table = make_table("game", vec![make_col("match", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub game_match: String"));
    }

    #[test]
    fn test_non_keyword_no_rename() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = gen(&table);
        assert!(!code.contains("sqlx(rename"));
    }

    // --- snake_case renaming ---

    #[test]
    fn test_camel_case_column_renamed() {
        let table = make_table("users", vec![make_col("CreatedAt", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub created_at: String"));
        assert!(code.contains("sqlx(rename = \"CreatedAt\")"));
    }

    #[test]
    fn test_mixed_case_column_renamed() {
        let table = make_table("users", vec![make_col("firstName", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub first_name: String"));
        assert!(code.contains("sqlx(rename = \"firstName\")"));
    }

    #[test]
    fn test_already_snake_case_no_rename() {
        let table = make_table("users", vec![make_col("created_at", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub created_at: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    // --- derives ---

    #[test]
    fn test_default_derives() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = gen(&table);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("sqlx::FromRow") || code.contains("sqlx :: FromRow"));
    }

    #[test]
    fn test_extra_derive_appended() {
        // `Hash` is not a default derive, so its presence proves the extra
        // derive list is applied. (`Serialize` is always derived and would
        // make this check vacuous.)
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let derives = vec!["Hash".to_string()];
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &derives, &HashMap::new());
        assert!(code.contains("Hash"));
    }

    #[test]
    fn test_extra_derives_both_serde() {
        // Requesting the serde derives explicitly must still leave the serde import in place.
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &derives, &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    // --- imports ---

    #[test]
    fn test_uuid_import() {
        let table = make_table("users", vec![make_col("id", "uuid", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_timestamptz_import() {
        let table = make_table("users", vec![make_col("created_at", "timestamptz", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    #[test]
    fn test_int4_only_serde_import() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert_eq!(imports.len(), 2);
        assert!(imports.iter().any(|i| i.contains("serde")));
        assert!(imports.iter().any(|i| i.contains("sqlx_gen::SqlxGen")));
    }

    #[test]
    fn test_multiple_imports_collected() {
        let table = make_table("users", vec![
            make_col("id", "uuid", false),
            make_col("created_at", "timestamptz", false),
        ]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid")));
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    // --- MySQL enum ---

    #[test]
    fn test_mysql_enum_column() {
        let table = make_table("users", vec![ColumnInfo {
            name: "status".to_string(),
            data_type: "enum".to_string(),
            udt_name: "enum('active','inactive')".to_string(),
            is_nullable: false,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "test_db".to_string(),
        }]);
        let schema = SchemaInfo::default();
        let (code, imports) = gen_with(&table, &schema, DatabaseKind::Mysql, &[], &HashMap::new());
        assert!(code.contains("UsersStatus"));
        assert!(imports.iter().any(|i| i.contains("super::types::")));
    }

    #[test]
    fn test_mysql_enum_nullable() {
        let table = make_table("users", vec![ColumnInfo {
            name: "status".to_string(),
            data_type: "enum".to_string(),
            udt_name: "enum('a','b')".to_string(),
            is_nullable: true,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "test_db".to_string(),
        }]);
        let schema = SchemaInfo::default();
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Mysql, &[], &HashMap::new());
        assert!(code.contains("Option<UsersStatus>"));
    }

    // --- type overrides ---

    #[test]
    fn test_type_override() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let schema = SchemaInfo::default();
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &overrides);
        assert!(code.contains("pub data: MyJson"));
    }

    #[test]
    fn test_type_override_absent() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let schema = SchemaInfo::default();
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(code.contains("Value"));
    }

    #[test]
    fn test_type_override_nullable() {
        let table = make_table("users", vec![make_col("data", "jsonb", true)]);
        let schema = SchemaInfo::default();
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &overrides);
        assert!(code.contains("Option<MyJson>"));
    }
}