1use std::collections::{BTreeSet, HashMap};
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::{DatabaseKind, TimeCrate};
8use crate::codegen::naming::singularize;
9use crate::codegen::{imports_for_derives, is_rust_keyword};
10use crate::introspect::{SchemaInfo, TableInfo};
11use crate::typemap;
12
13pub fn generate_struct(
14 table: &TableInfo,
15 db_kind: DatabaseKind,
16 schema_info: &SchemaInfo,
17 extra_derives: &[String],
18 type_overrides: &HashMap<String, String>,
19 is_view: bool,
20 time_crate: TimeCrate,
21) -> (TokenStream, BTreeSet<String>) {
22 let mut imports = BTreeSet::new();
23 for imp in imports_for_derives(extra_derives) {
24 imports.insert(imp);
25 }
26 let struct_name = format_ident!("{}", singularize(&table.name).to_upper_camel_case());
29
30 imports.insert("use serde::{Serialize, Deserialize};".to_string());
32 imports.insert("use sqlx_gen::SqlxGen;".to_string());
33 let mut derive_tokens = vec![
34 quote! { Debug },
35 quote! { Clone },
36 quote! { PartialEq },
37 quote! { Eq },
38 quote! { Serialize },
39 quote! { Deserialize },
40 quote! { sqlx::FromRow },
41 quote! { SqlxGen },
42 ];
43 for d in extra_derives {
44 let ident = format_ident!("{}", d);
45 derive_tokens.push(quote! { #ident });
46 }
47
48 let fields: Vec<TokenStream> = table
50 .columns
51 .iter()
52 .map(|col| {
53 let rust_type =
54 resolve_column_type(col, db_kind, table, schema_info, type_overrides, time_crate);
55 if let Some(imp) = &rust_type.needs_import {
56 imports.insert(imp.clone());
57 }
58
59 let field_name_snake = sanitize_rust_ident(&col.name.to_snake_case());
60 let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
64 let prefix = singularize(&table.name).to_snake_case();
65 let prefixed = format!("{}_{}", prefix, field_name_snake);
66 (prefixed, true)
67 } else {
68 let changed = field_name_snake != col.name;
69 (field_name_snake, changed)
70 };
71
72 let field_ident = format_ident!("{}", effective_name);
73 let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
74 let fallback = format_ident!("String");
75 quote! { #fallback }
76 });
77
78 let rename = if needs_rename {
79 let original = &col.name;
80 quote! { #[sqlx(rename = #original)] }
81 } else {
82 quote! {}
83 };
84
85 let (sql_type, is_sql_array) = detect_custom_sql_type(&col.udt_name, schema_info);
87 let has_pk = col.is_primary_key;
88 let has_sql_type = sql_type.is_some();
89 let has_default = col.column_default.is_some();
90
91 let sqlx_gen_attr = if has_pk || has_sql_type || has_default {
92 let pk_part = if has_pk {
93 quote! { primary_key, }
94 } else {
95 quote! {}
96 };
97 let sql_type_part = match &sql_type {
98 Some(t) => quote! { sql_type = #t, },
99 None => quote! {},
100 };
101 let array_part = if is_sql_array {
102 quote! { is_array, }
103 } else {
104 quote! {}
105 };
106 let default_part = match &col.column_default {
107 Some(d) => quote! { column_default = #d, },
108 None => quote! {},
109 };
110 quote! { #[sqlx_gen(#pk_part #sql_type_part #array_part #default_part)] }
111 } else {
112 quote! {}
113 };
114
115 quote! {
116 #sqlx_gen_attr
117 #rename
118 pub #field_ident: #type_tokens,
119 }
120 })
121 .collect();
122
123 let table_name_str = &table.name;
124 let schema_name_str = &table.schema_name;
125 let kind_str = if is_view { "view" } else { "table" };
126
127 let tokens = quote! {
128 #[derive(#(#derive_tokens),*)]
129 #[sqlx_gen(kind = #kind_str, schema = #schema_name_str, table = #table_name_str)]
130 pub struct #struct_name {
131 #(#fields)*
132 }
133 };
134
135 (tokens, imports)
136}
137
/// Turns an arbitrary column name into a string usable as a Rust identifier.
///
/// Every character outside `[A-Za-z0-9_]` becomes `_`, and a leading ASCII
/// digit is prefixed with `_`. Inputs that would otherwise produce an empty
/// string or a lone `_` (the reserved wildcard, which cannot name a struct
/// field) yield the placeholder `_field`. Keywords are not handled here; the
/// caller is responsible for them.
pub(crate) fn sanitize_rust_ident(name: &str) -> String {
    let mut out: String = name
        .chars()
        .map(|c| {
            if c.is_ascii_alphanumeric() || c == '_' {
                c
            } else {
                '_'
            }
        })
        .collect();
    // `_` alone is not a valid field name, so treat it like the empty case.
    if out.is_empty() || out == "_" {
        return "_field".to_string();
    }
    if out.starts_with(|c: char| c.is_ascii_digit()) {
        out.insert(0, '_');
    }
    out
}
165
166fn detect_custom_sql_type(udt_name: &str, schema_info: &SchemaInfo) -> (Option<String>, bool) {
171 let (base_name, is_array) = match udt_name.strip_prefix('_') {
172 Some(inner) => (inner, true),
173 None => (udt_name, false),
174 };
175
176 if schema_info.enums.iter().any(|e| e.name == base_name) {
178 return (Some(base_name.to_string()), is_array);
179 }
180
181 if schema_info
183 .composite_types
184 .iter()
185 .any(|c| c.name == base_name)
186 {
187 return (Some(base_name.to_string()), is_array);
188 }
189
190 let is_domain = schema_info.domains.iter().any(|d| d.name == base_name);
194 if !is_domain && !typemap::postgres::is_builtin(base_name) {
195 return (Some(base_name.to_string()), is_array);
196 }
197
198 if is_array {
202 return (Some(base_name.to_string()), true);
203 }
204
205 (None, false)
206}
207
208fn resolve_column_type(
209 col: &crate::introspect::ColumnInfo,
210 db_kind: DatabaseKind,
211 table: &TableInfo,
212 schema_info: &SchemaInfo,
213 type_overrides: &HashMap<String, String>,
214 time_crate: TimeCrate,
215) -> typemap::RustType {
216 if db_kind == DatabaseKind::Mysql && col.udt_name.starts_with("enum(") {
218 let enum_type_name = typemap::mysql::resolve_enum_type(&table.name, &col.name);
219 let rt = typemap::RustType::with_import(
220 &enum_type_name,
221 &format!("use super::types::{};", enum_type_name),
222 );
223 return if col.is_nullable {
224 rt.wrap_option()
225 } else {
226 rt
227 };
228 }
229
230 typemap::map_column(col, db_kind, schema_info, type_overrides, time_crate)
231}
232
// Unit tests for struct generation: naming, nullability, keyword handling,
// derives, import collection, MySQL enums, type overrides, array SQL-type
// annotations, and identifier sanitization.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    // Builds a table in the "public" schema with the given columns.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    // Builds a non-PK column with no default; `data_type` mirrors `udt_name`.
    fn make_col(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
            udt_schema: None,
            column_default: None,
        }
    }

    // Generates and formats a Postgres struct with default options.
    fn gen(table: &TableInfo) -> String {
        let schema = SchemaInfo::default();
        let (tokens, _) = generate_struct(
            table,
            DatabaseKind::Postgres,
            &schema,
            &[],
            &HashMap::new(),
            false,
            TimeCrate::Chrono,
        );
        parse_and_format(&tokens).unwrap()
    }

    // Like `gen`, but with explicit schema/db/derives/overrides; also returns imports.
    fn gen_with(
        table: &TableInfo,
        schema: &SchemaInfo,
        db: DatabaseKind,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_struct(
            table,
            db,
            schema,
            derives,
            overrides,
            false,
            TimeCrate::Chrono,
        );
        (parse_and_format(&tokens).unwrap(), imports)
    }

    // --- Struct naming and basic field mapping ---

    #[test]
    fn test_simple_table() {
        let table = make_table(
            "users",
            vec![
                make_col("id", "int4", false),
                make_col("name", "text", false),
            ],
        );
        let code = gen(&table);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub name: String"));
    }

    #[test]
    fn test_struct_name_pascal_case_and_singular() {
        let table = make_table("user_roles", vec![make_col("id", "int4", false)]);
        let code = gen(&table);
        assert!(
            code.contains("pub struct UserRole"),
            "expected singular PascalCase struct name, got:\n{}",
            code
        );
        assert!(!code.contains("pub struct UserRoles"));
    }

    #[test]
    fn test_struct_name_is_singular() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = gen(&table);
        assert!(
            code.contains("pub struct User"),
            "table 'users' must produce singular 'User' struct, got:\n{}",
            code
        );
        assert!(!code.contains("pub struct Users"));
    }

    #[test]
    fn test_struct_name_already_singular_unchanged() {
        let table = make_table("agent_connector", vec![make_col("id", "int4", false)]);
        let code = gen(&table);
        assert!(code.contains("pub struct AgentConnector"));
    }

    #[test]
    fn test_struct_name_uncountable_unchanged() {
        let table = make_table("news", vec![make_col("id", "int4", false)]);
        let code = gen(&table);
        assert!(code.contains("pub struct News"));
    }

    // Keyword columns are prefixed with the singular (not plural) table name.
    #[test]
    fn test_reserved_keyword_column_prefixed_with_singular_table() {
        let table = make_table(
            "products",
            vec![
                make_col("id", "int4", false),
                make_col("type", "text", false),
            ],
        );
        let code = gen(&table);
        assert!(
            code.contains("pub product_type:"),
            "expected singularized prefix 'product_type', got:\n{}",
            code
        );
        assert!(
            !code.contains("pub products_type:"),
            "must not use plural-form prefix, got:\n{}",
            code
        );
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    #[test]
    fn test_reserved_keyword_column_on_already_singular_table() {
        let table = make_table(
            "connector",
            vec![
                make_col("id", "int4", false),
                make_col("type", "text", false),
            ],
        );
        let code = gen(&table);
        assert!(code.contains("pub connector_type:"));
    }

    // --- Nullability ---

    #[test]
    fn test_nullable_column() {
        let table = make_table("users", vec![make_col("email", "text", true)]);
        let code = gen(&table);
        assert!(code.contains("pub email: Option<String>"));
    }

    #[test]
    fn test_non_nullable_column() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub name: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_mix_nullable() {
        let table = make_table(
            "users",
            vec![make_col("id", "int4", false), make_col("bio", "text", true)],
        );
        let code = gen(&table);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub bio: Option<String>"));
    }

    // --- Keyword renaming ---

    #[test]
    fn test_keyword_type_renamed() {
        let table = make_table("connector", vec![make_col("type", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub connector_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    #[test]
    fn test_keyword_fn_renamed() {
        let table = make_table("item", vec![make_col("fn", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub item_fn: String"));
        assert!(code.contains("sqlx(rename = \"fn\")"));
    }

    #[test]
    fn test_keyword_match_renamed() {
        let table = make_table("game", vec![make_col("match", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub game_match: String"));
    }

    #[test]
    fn test_non_keyword_no_rename() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = gen(&table);
        assert!(!code.contains("sqlx(rename"));
    }

    // --- Case conversion: non-snake_case columns get a rename attribute ---

    #[test]
    fn test_camel_case_column_renamed() {
        let table = make_table("users", vec![make_col("CreatedAt", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub created_at: String"));
        assert!(code.contains("sqlx(rename = \"CreatedAt\")"));
    }

    #[test]
    fn test_mixed_case_column_renamed() {
        let table = make_table("users", vec![make_col("firstName", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub first_name: String"));
        assert!(code.contains("sqlx(rename = \"firstName\")"));
    }

    #[test]
    fn test_already_snake_case_no_rename() {
        let table = make_table("users", vec![make_col("created_at", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub created_at: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    // --- Derives ---

    #[test]
    fn test_default_derives() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = gen(&table);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("sqlx::FromRow") || code.contains("sqlx :: FromRow"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let derives = vec!["Serialize".to_string()];
        let (code, _) = gen_with(
            &table,
            &schema,
            DatabaseKind::Postgres,
            &derives,
            &HashMap::new(),
        );
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derives_both_serde() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = gen_with(
            &table,
            &schema,
            DatabaseKind::Postgres,
            &derives,
            &HashMap::new(),
        );
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    // --- Import collection ---

    #[test]
    fn test_uuid_import() {
        let table = make_table("users", vec![make_col("id", "uuid", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(
            &table,
            &schema,
            DatabaseKind::Postgres,
            &[],
            &HashMap::new(),
        );
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_timestamptz_import() {
        let table = make_table("users", vec![make_col("created_at", "timestamptz", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(
            &table,
            &schema,
            DatabaseKind::Postgres,
            &[],
            &HashMap::new(),
        );
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    // With no extra derives and only builtin column types, exactly the serde
    // and SqlxGen imports are expected.
    #[test]
    fn test_int4_only_serde_import() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(
            &table,
            &schema,
            DatabaseKind::Postgres,
            &[],
            &HashMap::new(),
        );
        assert_eq!(imports.len(), 2);
        assert!(imports.iter().any(|i| i.contains("serde")));
        assert!(imports.iter().any(|i| i.contains("sqlx_gen::SqlxGen")));
    }

    #[test]
    fn test_multiple_imports_collected() {
        let table = make_table(
            "users",
            vec![
                make_col("id", "uuid", false),
                make_col("created_at", "timestamptz", false),
            ],
        );
        let schema = SchemaInfo::default();
        let (_, imports) = gen_with(
            &table,
            &schema,
            DatabaseKind::Postgres,
            &[],
            &HashMap::new(),
        );
        assert!(imports.iter().any(|i| i.contains("uuid")));
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    // --- MySQL inline enums ---

    #[test]
    fn test_mysql_enum_column() {
        let table = make_table(
            "users",
            vec![ColumnInfo {
                name: "status".to_string(),
                data_type: "enum".to_string(),
                udt_name: "enum('active','inactive')".to_string(),
                is_nullable: false,
                is_primary_key: false,
                ordinal_position: 0,
                schema_name: "test_db".to_string(),
                udt_schema: None,
                column_default: None,
            }],
        );
        let schema = SchemaInfo::default();
        let (code, imports) = gen_with(&table, &schema, DatabaseKind::Mysql, &[], &HashMap::new());
        assert!(code.contains("UsersStatus"));
        assert!(imports.iter().any(|i| i.contains("super::types::")));
    }

    #[test]
    fn test_mysql_enum_nullable() {
        let table = make_table(
            "users",
            vec![ColumnInfo {
                name: "status".to_string(),
                data_type: "enum".to_string(),
                udt_name: "enum('a','b')".to_string(),
                is_nullable: true,
                is_primary_key: false,
                ordinal_position: 0,
                schema_name: "test_db".to_string(),
                udt_schema: None,
                column_default: None,
            }],
        );
        let schema = SchemaInfo::default();
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Mysql, &[], &HashMap::new());
        assert!(code.contains("Option<UsersStatus>"));
    }

    // --- Type overrides (keyed by udt_name) ---

    #[test]
    fn test_type_override() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let schema = SchemaInfo::default();
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &overrides);
        assert!(code.contains("pub data: MyJson"));
    }

    #[test]
    fn test_type_override_absent() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let schema = SchemaInfo::default();
        let (code, _) = gen_with(
            &table,
            &schema,
            DatabaseKind::Postgres,
            &[],
            &HashMap::new(),
        );
        assert!(code.contains("Value"));
    }

    #[test]
    fn test_type_override_nullable() {
        let table = make_table("users", vec![make_col("data", "jsonb", true)]);
        let schema = SchemaInfo::default();
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &overrides);
        assert!(code.contains("Option<MyJson>"));
    }

    // --- Native Postgres arrays (`_`-prefixed udt_name) get sql_type + is_array ---

    #[test]
    fn test_native_array_text_gets_sql_type_annotation() {
        let table = make_table("posts", vec![make_col("tags", "_text", false)]);
        let code = gen(&table);
        assert!(code.contains("Vec<String>"));
        assert!(code.contains("sql_type = \"text\""));
        assert!(code.contains("is_array"));
    }

    #[test]
    fn test_native_array_int4_gets_sql_type_annotation() {
        let table = make_table("data", vec![make_col("values", "_int4", false)]);
        let code = gen(&table);
        assert!(code.contains("Vec<i32>"));
        assert!(code.contains("sql_type = \"int4\""));
        assert!(code.contains("is_array"));
    }

    #[test]
    fn test_native_array_nullable_gets_sql_type_annotation() {
        let table = make_table("posts", vec![make_col("tags", "_text", true)]);
        let code = gen(&table);
        assert!(code.contains("Option<Vec<String>>"));
        assert!(code.contains("sql_type = \"text\""));
        assert!(code.contains("is_array"));
    }

    #[test]
    fn test_scalar_builtin_no_sql_type_annotation() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub name: String"));
        assert!(!code.contains("sql_type"));
    }

    // --- sanitize_rust_ident ---

    #[test]
    fn test_sanitize_replaces_dash() {
        assert_eq!(sanitize_rust_ident("user-id"), "user_id");
    }

    #[test]
    fn test_sanitize_replaces_space() {
        assert_eq!(sanitize_rust_ident("created at"), "created_at");
    }

    #[test]
    fn test_sanitize_replaces_dot() {
        assert_eq!(sanitize_rust_ident("a.b"), "a_b");
    }

    #[test]
    fn test_sanitize_prefixes_leading_digit() {
        assert_eq!(sanitize_rust_ident("123abc"), "_123abc");
    }

    #[test]
    fn test_sanitize_empty_becomes_placeholder() {
        assert_eq!(sanitize_rust_ident(""), "_field");
    }

    #[test]
    fn test_sanitize_leaves_valid_ident_unchanged() {
        assert_eq!(sanitize_rust_ident("user_id"), "user_id");
        assert_eq!(sanitize_rust_ident("_private"), "_private");
    }

    // End-to-end: a dashed column yields a valid field plus a rename attribute.
    #[test]
    fn test_column_with_dash_generates_valid_rust() {
        let table = make_table("users", vec![make_col("user-id", "int4", false)]);
        let code = gen(&table);
        assert!(
            code.contains("pub user_id:") || code.contains("user_id:"),
            "expected sanitized identifier, got:\n{}",
            code
        );
        assert!(code.contains("sqlx(rename = \"user-id\")"));
    }
}