1use std::collections::{BTreeSet, HashMap};
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::{DatabaseKind, TimeCrate};
8use crate::codegen::{imports_for_derives, is_rust_keyword};
9use crate::introspect::{SchemaInfo, TableInfo};
10use crate::typemap;
11
12pub fn generate_struct(
13 table: &TableInfo,
14 db_kind: DatabaseKind,
15 schema_info: &SchemaInfo,
16 extra_derives: &[String],
17 type_overrides: &HashMap<String, String>,
18 is_view: bool,
19 time_crate: TimeCrate,
20) -> (TokenStream, BTreeSet<String>) {
21 let mut imports = BTreeSet::new();
22 for imp in imports_for_derives(extra_derives) {
23 imports.insert(imp);
24 }
25 let struct_name = format_ident!("{}", table.name.to_upper_camel_case());
26
27 imports.insert("use serde::{Serialize, Deserialize};".to_string());
29 imports.insert("use sqlx_gen::SqlxGen;".to_string());
30 let mut derive_tokens = vec![
31 quote! { Debug },
32 quote! { Clone },
33 quote! { PartialEq },
34 quote! { Eq },
35 quote! { Serialize },
36 quote! { Deserialize },
37 quote! { sqlx::FromRow },
38 quote! { SqlxGen },
39 ];
40 for d in extra_derives {
41 let ident = format_ident!("{}", d);
42 derive_tokens.push(quote! { #ident });
43 }
44
45 let fields: Vec<TokenStream> = table
47 .columns
48 .iter()
49 .map(|col| {
50 let rust_type = resolve_column_type(col, db_kind, table, schema_info, type_overrides, time_crate);
51 if let Some(imp) = &rust_type.needs_import {
52 imports.insert(imp.clone());
53 }
54
55 let field_name_snake = col.name.to_snake_case();
56 let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
59 let prefixed = format!(
60 "{}_{}",
61 table.name.to_snake_case(),
62 field_name_snake
63 );
64 (prefixed, true)
65 } else {
66 let changed = field_name_snake != col.name;
67 (field_name_snake, changed)
68 };
69
70 let field_ident = format_ident!("{}", effective_name);
71 let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
72 let fallback = format_ident!("String");
73 quote! { #fallback }
74 });
75
76 let rename = if needs_rename {
77 let original = &col.name;
78 quote! { #[sqlx(rename = #original)] }
79 } else {
80 quote! {}
81 };
82
83 let (sql_type, is_sql_array) = detect_custom_sql_type(&col.udt_name, schema_info);
85 let has_pk = col.is_primary_key;
86 let has_sql_type = sql_type.is_some();
87
88 let sqlx_gen_attr = if has_pk || has_sql_type {
89 let pk_part = if has_pk { quote! { primary_key, } } else { quote! {} };
90 let sql_type_part = match &sql_type {
91 Some(t) => quote! { sql_type = #t, },
92 None => quote! {},
93 };
94 let array_part = if is_sql_array { quote! { is_array, } } else { quote! {} };
95 quote! { #[sqlx_gen(#pk_part #sql_type_part #array_part)] }
96 } else {
97 quote! {}
98 };
99
100 quote! {
101 #sqlx_gen_attr
102 #rename
103 pub #field_ident: #type_tokens,
104 }
105 })
106 .collect();
107
108 let table_name_str = &table.name;
109 let schema_name_str = &table.schema_name;
110 let kind_str = if is_view { "view" } else { "table" };
111
112 let tokens = quote! {
113 #[derive(#(#derive_tokens),*)]
114 #[sqlx_gen(kind = #kind_str, schema = #schema_name_str, table = #table_name_str)]
115 pub struct #struct_name {
116 #(#fields)*
117 }
118 };
119
120 (tokens, imports)
121}
122
123fn detect_custom_sql_type(udt_name: &str, schema_info: &SchemaInfo) -> (Option<String>, bool) {
128 let (base_name, is_array) = match udt_name.strip_prefix('_') {
129 Some(inner) => (inner, true),
130 None => (udt_name, false),
131 };
132
133 if schema_info.enums.iter().any(|e| e.name == base_name) {
135 return (Some(base_name.to_string()), is_array);
136 }
137
138 if schema_info.composite_types.iter().any(|c| c.name == base_name) {
140 return (Some(base_name.to_string()), is_array);
141 }
142
143 let is_domain = schema_info.domains.iter().any(|d| d.name == base_name);
147 if !is_domain && !typemap::postgres::is_builtin(base_name) {
148 return (Some(base_name.to_string()), is_array);
149 }
150
151 (None, false)
152}
153
154fn resolve_column_type(
155 col: &crate::introspect::ColumnInfo,
156 db_kind: DatabaseKind,
157 table: &TableInfo,
158 schema_info: &SchemaInfo,
159 type_overrides: &HashMap<String, String>,
160 time_crate: TimeCrate,
161) -> typemap::RustType {
162 if db_kind == DatabaseKind::Mysql && col.udt_name.starts_with("enum(") {
164 let enum_type_name = typemap::mysql::resolve_enum_type(&table.name, &col.name);
165 let rt = typemap::RustType::with_import(
166 &enum_type_name,
167 &format!("use super::types::{};", enum_type_name),
168 );
169 return if col.is_nullable {
170 rt.wrap_option()
171 } else {
172 rt
173 };
174 }
175
176 typemap::map_column(col, db_kind, schema_info, type_overrides, time_crate)
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    /// Builds a table in the `public` schema.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// Builds a non-primary-key column whose `data_type` equals its `udt_name`.
    fn make_col(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
            column_default: None,
        }
    }

    // Named `gen_code` rather than `gen`, which is a reserved keyword in the
    // Rust 2024 edition.
    /// Generates formatted code for a Postgres table with default settings.
    fn gen_code(table: &TableInfo) -> String {
        let schema = SchemaInfo::default();
        let (tokens, _) = generate_struct(
            table,
            DatabaseKind::Postgres,
            &schema,
            &[],
            &HashMap::new(),
            false,
            TimeCrate::Chrono,
        );
        parse_and_format(&tokens)
    }

    /// Generates formatted code and the import set for a table (not a view).
    fn gen_with(
        table: &TableInfo,
        schema: &SchemaInfo,
        db: DatabaseKind,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) =
            generate_struct(table, db, schema, derives, overrides, false, TimeCrate::Chrono);
        (parse_and_format(&tokens), imports)
    }

    #[test]
    fn test_simple_table() {
        let table = make_table(
            "users",
            vec![make_col("id", "int4", false), make_col("name", "text", false)],
        );
        let code = gen_code(&table);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub name: String"));
    }

    #[test]
    fn test_struct_name_pascal_case() {
        let table = make_table("user_roles", vec![make_col("id", "int4", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub struct UserRoles"));
    }

    #[test]
    fn test_struct_name_simple() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub struct Users"));
    }

    #[test]
    fn test_nullable_column() {
        let table = make_table("users", vec![make_col("email", "text", true)]);
        let code = gen_code(&table);
        assert!(code.contains("pub email: Option<String>"));
    }

    #[test]
    fn test_non_nullable_column() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub name: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_mix_nullable() {
        let table = make_table(
            "users",
            vec![make_col("id", "int4", false), make_col("bio", "text", true)],
        );
        let code = gen_code(&table);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub bio: Option<String>"));
    }

    #[test]
    fn test_keyword_type_renamed() {
        let table = make_table("connector", vec![make_col("type", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub connector_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    #[test]
    fn test_keyword_fn_renamed() {
        let table = make_table("item", vec![make_col("fn", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub item_fn: String"));
        assert!(code.contains("sqlx(rename = \"fn\")"));
    }

    #[test]
    fn test_keyword_match_renamed() {
        let table = make_table("game", vec![make_col("match", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub game_match: String"));
    }

    #[test]
    fn test_non_keyword_no_rename() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = gen_code(&table);
        assert!(!code.contains("sqlx(rename"));
    }

    #[test]
    fn test_camel_case_column_renamed() {
        let table = make_table("users", vec![make_col("CreatedAt", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub created_at: String"));
        assert!(code.contains("sqlx(rename = \"CreatedAt\")"));
    }

    #[test]
    fn test_mixed_case_column_renamed() {
        let table = make_table("users", vec![make_col("firstName", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub first_name: String"));
        assert!(code.contains("sqlx(rename = \"firstName\")"));
    }

    #[test]
    fn test_already_snake_case_no_rename() {
        let table = make_table("users", vec![make_col("created_at", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub created_at: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    #[test]
    fn test_default_derives() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = gen_code(&table);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("sqlx::FromRow") || code.contains("sqlx :: FromRow"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let derives = vec!["Serialize".to_string()];
        let (code, _) =
            gen_with(&table, &schema, DatabaseKind::Postgres, &derives, &HashMap::new());
        assert!(code.contains("Serialize"));
    }

    // `Serialize` is always a default derive, so the test above cannot tell
    // whether `extra_derives` is honored. `Hash` is not a default.
    #[test]
    fn test_extra_derive_not_in_defaults() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let derives = vec!["Hash".to_string()];
        let (code, _) =
            gen_with(&table, &schema, DatabaseKind::Postgres, &derives, &HashMap::new());
        assert!(code.contains("Hash"));
    }

    #[test]
    fn test_extra_derives_both_serde() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) =
            gen_with(&table, &schema, DatabaseKind::Postgres, &derives, &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_uuid_import() {
        let table = make_table("users", vec![make_col("id", "uuid", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_timestamptz_import() {
        let table = make_table("users", vec![make_col("created_at", "timestamptz", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    #[test]
    fn test_int4_only_serde_import() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert_eq!(imports.len(), 2);
        assert!(imports.iter().any(|i| i.contains("serde")));
        assert!(imports.iter().any(|i| i.contains("sqlx_gen::SqlxGen")));
    }

    #[test]
    fn test_multiple_imports_collected() {
        let table = make_table(
            "users",
            vec![
                make_col("id", "uuid", false),
                make_col("created_at", "timestamptz", false),
            ],
        );
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid")));
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    #[test]
    fn test_mysql_enum_column() {
        let table = make_table(
            "users",
            vec![ColumnInfo {
                name: "status".to_string(),
                data_type: "enum".to_string(),
                udt_name: "enum('active','inactive')".to_string(),
                is_nullable: false,
                is_primary_key: false,
                ordinal_position: 0,
                schema_name: "test_db".to_string(),
                column_default: None,
            }],
        );
        let schema = SchemaInfo::default();
        let (code, imports) = gen_with(&table, &schema, DatabaseKind::Mysql, &[], &HashMap::new());
        assert!(code.contains("UsersStatus"));
        assert!(imports.iter().any(|i| i.contains("super::types::")));
    }

    #[test]
    fn test_mysql_enum_nullable() {
        let table = make_table(
            "users",
            vec![ColumnInfo {
                name: "status".to_string(),
                data_type: "enum".to_string(),
                udt_name: "enum('a','b')".to_string(),
                is_nullable: true,
                is_primary_key: false,
                ordinal_position: 0,
                schema_name: "test_db".to_string(),
                column_default: None,
            }],
        );
        let schema = SchemaInfo::default();
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Mysql, &[], &HashMap::new());
        assert!(code.contains("Option<UsersStatus>"));
    }

    #[test]
    fn test_type_override() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let schema = SchemaInfo::default();
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &overrides);
        assert!(code.contains("pub data: MyJson"));
    }

    #[test]
    fn test_type_override_absent() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let schema = SchemaInfo::default();
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(code.contains("Value"));
    }

    #[test]
    fn test_type_override_nullable() {
        let table = make_table("users", vec![make_col("data", "jsonb", true)]);
        let schema = SchemaInfo::default();
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &overrides);
        assert!(code.contains("Option<MyJson>"));
    }

    #[test]
    fn test_table_kind_attribute() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = gen_code(&table);
        assert!(code.contains("\"table\""));
        assert!(!code.contains("\"view\""));
    }

    #[test]
    fn test_view_kind_attribute() {
        let table = make_table("active_users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let (tokens, _) = generate_struct(
            &table,
            DatabaseKind::Postgres,
            &schema,
            &[],
            &HashMap::new(),
            true,
            TimeCrate::Chrono,
        );
        let code = parse_and_format(&tokens);
        assert!(code.contains("\"view\""));
        assert!(!code.contains("\"table\""));
    }

    #[test]
    fn test_primary_key_attribute() {
        let mut id = make_col("id", "int4", false);
        id.is_primary_key = true;
        let table = make_table("users", vec![id, make_col("name", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("primary_key"));
    }

    #[test]
    fn test_no_primary_key_attribute_by_default() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = gen_code(&table);
        assert!(!code.contains("primary_key"));
    }
}