1use std::collections::{BTreeSet, HashMap};
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::{DatabaseKind, TimeCrate};
8use crate::codegen::{imports_for_derives, is_rust_keyword};
9use crate::introspect::{SchemaInfo, TableInfo};
10use crate::typemap;
11
12pub fn generate_struct(
13 table: &TableInfo,
14 db_kind: DatabaseKind,
15 schema_info: &SchemaInfo,
16 extra_derives: &[String],
17 type_overrides: &HashMap<String, String>,
18 is_view: bool,
19 time_crate: TimeCrate,
20) -> (TokenStream, BTreeSet<String>) {
21 let mut imports = BTreeSet::new();
22 for imp in imports_for_derives(extra_derives) {
23 imports.insert(imp);
24 }
25 let struct_name = format_ident!("{}", table.name.to_upper_camel_case());
26
27 imports.insert("use serde::{Serialize, Deserialize};".to_string());
29 imports.insert("use sqlx_gen::SqlxGen;".to_string());
30 let mut derive_tokens = vec![
31 quote! { Debug },
32 quote! { Clone },
33 quote! { PartialEq },
34 quote! { Eq },
35 quote! { Serialize },
36 quote! { Deserialize },
37 quote! { sqlx::FromRow },
38 quote! { SqlxGen },
39 ];
40 for d in extra_derives {
41 let ident = format_ident!("{}", d);
42 derive_tokens.push(quote! { #ident });
43 }
44
45 let fields: Vec<TokenStream> = table
47 .columns
48 .iter()
49 .map(|col| {
50 let rust_type = resolve_column_type(col, db_kind, table, schema_info, type_overrides, time_crate);
51 if let Some(imp) = &rust_type.needs_import {
52 imports.insert(imp.clone());
53 }
54
55 let field_name_snake = col.name.to_snake_case();
56 let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
59 let prefixed = format!(
60 "{}_{}",
61 table.name.to_snake_case(),
62 field_name_snake
63 );
64 (prefixed, true)
65 } else {
66 let changed = field_name_snake != col.name;
67 (field_name_snake, changed)
68 };
69
70 let field_ident = format_ident!("{}", effective_name);
71 let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
72 let fallback = format_ident!("String");
73 quote! { #fallback }
74 });
75
76 let rename = if needs_rename {
77 let original = &col.name;
78 quote! { #[sqlx(rename = #original)] }
79 } else {
80 quote! {}
81 };
82
83 let (sql_type, is_sql_array) = detect_custom_sql_type(&col.udt_name, schema_info);
85 let has_pk = col.is_primary_key;
86 let has_sql_type = sql_type.is_some();
87 let has_default = col.column_default.is_some();
88
89 let sqlx_gen_attr = if has_pk || has_sql_type || has_default {
90 let pk_part = if has_pk { quote! { primary_key, } } else { quote! {} };
91 let sql_type_part = match &sql_type {
92 Some(t) => quote! { sql_type = #t, },
93 None => quote! {},
94 };
95 let array_part = if is_sql_array { quote! { is_array, } } else { quote! {} };
96 let default_part = match &col.column_default {
97 Some(d) => quote! { column_default = #d, },
98 None => quote! {},
99 };
100 quote! { #[sqlx_gen(#pk_part #sql_type_part #array_part #default_part)] }
101 } else {
102 quote! {}
103 };
104
105 quote! {
106 #sqlx_gen_attr
107 #rename
108 pub #field_ident: #type_tokens,
109 }
110 })
111 .collect();
112
113 let table_name_str = &table.name;
114 let schema_name_str = &table.schema_name;
115 let kind_str = if is_view { "view" } else { "table" };
116
117 let tokens = quote! {
118 #[derive(#(#derive_tokens),*)]
119 #[sqlx_gen(kind = #kind_str, schema = #schema_name_str, table = #table_name_str)]
120 pub struct #struct_name {
121 #(#fields)*
122 }
123 };
124
125 (tokens, imports)
126}
127
128fn detect_custom_sql_type(udt_name: &str, schema_info: &SchemaInfo) -> (Option<String>, bool) {
133 let (base_name, is_array) = match udt_name.strip_prefix('_') {
134 Some(inner) => (inner, true),
135 None => (udt_name, false),
136 };
137
138 if schema_info.enums.iter().any(|e| e.name == base_name) {
140 return (Some(base_name.to_string()), is_array);
141 }
142
143 if schema_info.composite_types.iter().any(|c| c.name == base_name) {
145 return (Some(base_name.to_string()), is_array);
146 }
147
148 let is_domain = schema_info.domains.iter().any(|d| d.name == base_name);
152 if !is_domain && !typemap::postgres::is_builtin(base_name) {
153 return (Some(base_name.to_string()), is_array);
154 }
155
156 if is_array {
160 return (Some(base_name.to_string()), true);
161 }
162
163 (None, false)
164}
165
166fn resolve_column_type(
167 col: &crate::introspect::ColumnInfo,
168 db_kind: DatabaseKind,
169 table: &TableInfo,
170 schema_info: &SchemaInfo,
171 type_overrides: &HashMap<String, String>,
172 time_crate: TimeCrate,
173) -> typemap::RustType {
174 if db_kind == DatabaseKind::Mysql && col.udt_name.starts_with("enum(") {
176 let enum_type_name = typemap::mysql::resolve_enum_type(&table.name, &col.name);
177 let rt = typemap::RustType::with_import(
178 &enum_type_name,
179 &format!("use super::types::{};", enum_type_name),
180 );
181 return if col.is_nullable {
182 rt.wrap_option()
183 } else {
184 rt
185 };
186 }
187
188 typemap::map_column(col, db_kind, schema_info, type_overrides, time_crate)
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    /// Builds a `public`-schema table with the given columns.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// Builds a non-key column without a default, using `udt_name` as both the
    /// data type and the underlying type name.
    fn make_col(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
            column_default: None,
        }
    }

    /// Renders `table` as formatted Postgres code with default options.
    /// (Not named `gen`: that is a reserved keyword in Rust 2024.)
    fn gen_code(table: &TableInfo) -> String {
        let schema = SchemaInfo::default();
        let (tokens, _) = generate_struct(
            table,
            DatabaseKind::Postgres,
            &schema,
            &[],
            &HashMap::new(),
            false,
            TimeCrate::Chrono,
        );
        parse_and_format(&tokens)
    }

    /// Renders `table` with explicit schema, backend, derives and overrides,
    /// returning the formatted code and the collected imports.
    fn gen_with(
        table: &TableInfo,
        schema: &SchemaInfo,
        db: DatabaseKind,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) =
            generate_struct(table, db, schema, derives, overrides, false, TimeCrate::Chrono);
        (parse_and_format(&tokens), imports)
    }

    #[test]
    fn test_simple_table() {
        let table = make_table("users", vec![
            make_col("id", "int4", false),
            make_col("name", "text", false),
        ]);
        let code = gen_code(&table);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub name: String"));
    }

    #[test]
    fn test_struct_name_pascal_case() {
        let table = make_table("user_roles", vec![make_col("id", "int4", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub struct UserRoles"));
    }

    #[test]
    fn test_struct_name_simple() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub struct Users"));
    }

    #[test]
    fn test_table_kind_attribute() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = gen_code(&table);
        assert!(code.contains("kind = \"table\""));
    }

    #[test]
    fn test_view_kind_attribute() {
        let table = make_table("active_users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let (tokens, _) = generate_struct(
            &table,
            DatabaseKind::Postgres,
            &schema,
            &[],
            &HashMap::new(),
            true,
            TimeCrate::Chrono,
        );
        let code = parse_and_format(&tokens);
        assert!(code.contains("kind = \"view\""));
    }

    #[test]
    fn test_nullable_column() {
        let table = make_table("users", vec![make_col("email", "text", true)]);
        let code = gen_code(&table);
        assert!(code.contains("pub email: Option<String>"));
    }

    #[test]
    fn test_non_nullable_column() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub name: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_mix_nullable() {
        let table = make_table("users", vec![
            make_col("id", "int4", false),
            make_col("bio", "text", true),
        ]);
        let code = gen_code(&table);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub bio: Option<String>"));
    }

    #[test]
    fn test_keyword_type_renamed() {
        let table = make_table("connector", vec![make_col("type", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub connector_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    #[test]
    fn test_keyword_fn_renamed() {
        let table = make_table("item", vec![make_col("fn", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub item_fn: String"));
        assert!(code.contains("sqlx(rename = \"fn\")"));
    }

    #[test]
    fn test_keyword_match_renamed() {
        let table = make_table("game", vec![make_col("match", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub game_match: String"));
    }

    #[test]
    fn test_non_keyword_no_rename() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = gen_code(&table);
        assert!(!code.contains("sqlx(rename"));
    }

    #[test]
    fn test_camel_case_column_renamed() {
        let table = make_table("users", vec![make_col("CreatedAt", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub created_at: String"));
        assert!(code.contains("sqlx(rename = \"CreatedAt\")"));
    }

    #[test]
    fn test_mixed_case_column_renamed() {
        let table = make_table("users", vec![make_col("firstName", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub first_name: String"));
        assert!(code.contains("sqlx(rename = \"firstName\")"));
    }

    #[test]
    fn test_already_snake_case_no_rename() {
        let table = make_table("users", vec![make_col("created_at", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub created_at: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    #[test]
    fn test_default_derives() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = gen_code(&table);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("Serialize"));
        assert!(code.contains("sqlx::FromRow") || code.contains("sqlx :: FromRow"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let derives = vec!["Serialize".to_string()];
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &derives, &HashMap::new());
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derive_appended() {
        // `Hash` is not a default derive, so it can only appear if the extra
        // derive was actually emitted.
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let derives = vec!["Hash".to_string()];
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &derives, &HashMap::new());
        assert!(code.contains("Hash"));
    }

    #[test]
    fn test_extra_derives_both_serde() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &derives, &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    #[test]
    fn test_primary_key_attribute() {
        let mut col = make_col("id", "int4", false);
        col.is_primary_key = true;
        let table = make_table("users", vec![col]);
        let code = gen_code(&table);
        assert!(code.contains("primary_key"));
    }

    #[test]
    fn test_column_default_attribute() {
        let mut col = make_col("created_at", "text", false);
        col.column_default = Some("now()".to_string());
        let table = make_table("users", vec![col]);
        let code = gen_code(&table);
        assert!(code.contains("column_default"));
        assert!(code.contains("now()"));
    }

    #[test]
    fn test_uuid_import() {
        let table = make_table("users", vec![make_col("id", "uuid", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_timestamptz_import() {
        let table = make_table("users", vec![make_col("created_at", "timestamptz", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    #[test]
    fn test_int4_only_serde_import() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert_eq!(imports.len(), 2);
        assert!(imports.iter().any(|i| i.contains("serde")));
        assert!(imports.iter().any(|i| i.contains("sqlx_gen::SqlxGen")));
    }

    #[test]
    fn test_multiple_imports_collected() {
        let table = make_table("users", vec![
            make_col("id", "uuid", false),
            make_col("created_at", "timestamptz", false),
        ]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(imports.iter().any(|i| i.contains("uuid")));
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    #[test]
    fn test_mysql_enum_column() {
        let table = make_table("users", vec![ColumnInfo {
            name: "status".to_string(),
            data_type: "enum".to_string(),
            udt_name: "enum('active','inactive')".to_string(),
            is_nullable: false,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "test_db".to_string(),
            column_default: None,
        }]);
        let schema = SchemaInfo::default();
        let (code, imports) = gen_with(&table, &schema, DatabaseKind::Mysql, &[], &HashMap::new());
        assert!(code.contains("UsersStatus"));
        assert!(imports.iter().any(|i| i.contains("super::types::")));
    }

    #[test]
    fn test_mysql_enum_nullable() {
        let table = make_table("users", vec![ColumnInfo {
            name: "status".to_string(),
            data_type: "enum".to_string(),
            udt_name: "enum('a','b')".to_string(),
            is_nullable: true,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "test_db".to_string(),
            column_default: None,
        }]);
        let schema = SchemaInfo::default();
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Mysql, &[], &HashMap::new());
        assert!(code.contains("Option<UsersStatus>"));
    }

    #[test]
    fn test_type_override() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let schema = SchemaInfo::default();
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &overrides);
        assert!(code.contains("pub data: MyJson"));
    }

    #[test]
    fn test_type_override_absent() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let schema = SchemaInfo::default();
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &HashMap::new());
        assert!(code.contains("Value"));
    }

    #[test]
    fn test_type_override_nullable() {
        let table = make_table("users", vec![make_col("data", "jsonb", true)]);
        let schema = SchemaInfo::default();
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &overrides);
        assert!(code.contains("Option<MyJson>"));
    }

    #[test]
    fn test_native_array_text_gets_sql_type_annotation() {
        let table = make_table("posts", vec![make_col("tags", "_text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("Vec<String>"));
        assert!(code.contains("sql_type = \"text\""));
        assert!(code.contains("is_array"));
    }

    #[test]
    fn test_native_array_int4_gets_sql_type_annotation() {
        let table = make_table("data", vec![make_col("values", "_int4", false)]);
        let code = gen_code(&table);
        assert!(code.contains("Vec<i32>"));
        assert!(code.contains("sql_type = \"int4\""));
        assert!(code.contains("is_array"));
    }

    #[test]
    fn test_native_array_nullable_gets_sql_type_annotation() {
        let table = make_table("posts", vec![make_col("tags", "_text", true)]);
        let code = gen_code(&table);
        assert!(code.contains("Option<Vec<String>>"));
        assert!(code.contains("sql_type = \"text\""));
        assert!(code.contains("is_array"));
    }

    #[test]
    fn test_scalar_builtin_no_sql_type_annotation() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = gen_code(&table);
        assert!(code.contains("pub name: String"));
        assert!(!code.contains("sql_type"));
    }
}