1use std::collections::{BTreeSet, HashMap};
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::{DatabaseKind, TimeCrate};
8use crate::codegen::naming::singularize;
9use crate::codegen::{imports_for_derives, is_rust_keyword};
10use crate::introspect::{SchemaInfo, TableInfo};
11use crate::typemap;
12
13pub fn generate_struct(
14 table: &TableInfo,
15 db_kind: DatabaseKind,
16 schema_info: &SchemaInfo,
17 extra_derives: &[String],
18 type_overrides: &HashMap<String, String>,
19 is_view: bool,
20 time_crate: TimeCrate,
21) -> (TokenStream, BTreeSet<String>) {
22 let mut imports = BTreeSet::new();
23 for imp in imports_for_derives(extra_derives) {
24 imports.insert(imp);
25 }
26 let struct_name = format_ident!("{}", singularize(&table.name).to_upper_camel_case());
29
30 imports.insert("use serde::{Serialize, Deserialize};".to_string());
32 imports.insert("use sqlx_gen::SqlxGen;".to_string());
33 let mut derive_tokens = vec![
34 quote! { Debug },
35 quote! { Clone },
36 quote! { PartialEq },
37 quote! { Eq },
38 quote! { Serialize },
39 quote! { Deserialize },
40 quote! { sqlx::FromRow },
41 quote! { SqlxGen },
42 ];
43 for d in extra_derives {
44 let ident = format_ident!("{}", d);
45 derive_tokens.push(quote! { #ident });
46 }
47
48 let fields: Vec<TokenStream> = table
50 .columns
51 .iter()
52 .map(|col| {
53 let rust_type =
54 resolve_column_type(col, db_kind, table, schema_info, type_overrides, time_crate);
55 if let Some(imp) = &rust_type.needs_import {
56 imports.insert(imp.clone());
57 }
58
59 let field_name_snake = sanitize_rust_ident(&col.name.to_snake_case());
60 let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
63 let prefixed = format!("{}_{}", table.name.to_snake_case(), field_name_snake);
64 (prefixed, true)
65 } else {
66 let changed = field_name_snake != col.name;
67 (field_name_snake, changed)
68 };
69
70 let field_ident = format_ident!("{}", effective_name);
71 let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
72 let fallback = format_ident!("String");
73 quote! { #fallback }
74 });
75
76 let rename = if needs_rename {
77 let original = &col.name;
78 quote! { #[sqlx(rename = #original)] }
79 } else {
80 quote! {}
81 };
82
83 let (sql_type, is_sql_array) = detect_custom_sql_type(&col.udt_name, schema_info);
85 let has_pk = col.is_primary_key;
86 let has_sql_type = sql_type.is_some();
87 let has_default = col.column_default.is_some();
88
89 let sqlx_gen_attr = if has_pk || has_sql_type || has_default {
90 let pk_part = if has_pk {
91 quote! { primary_key, }
92 } else {
93 quote! {}
94 };
95 let sql_type_part = match &sql_type {
96 Some(t) => quote! { sql_type = #t, },
97 None => quote! {},
98 };
99 let array_part = if is_sql_array {
100 quote! { is_array, }
101 } else {
102 quote! {}
103 };
104 let default_part = match &col.column_default {
105 Some(d) => quote! { column_default = #d, },
106 None => quote! {},
107 };
108 quote! { #[sqlx_gen(#pk_part #sql_type_part #array_part #default_part)] }
109 } else {
110 quote! {}
111 };
112
113 quote! {
114 #sqlx_gen_attr
115 #rename
116 pub #field_ident: #type_tokens,
117 }
118 })
119 .collect();
120
121 let table_name_str = &table.name;
122 let schema_name_str = &table.schema_name;
123 let kind_str = if is_view { "view" } else { "table" };
124
125 let tokens = quote! {
126 #[derive(#(#derive_tokens),*)]
127 #[sqlx_gen(kind = #kind_str, schema = #schema_name_str, table = #table_name_str)]
128 pub struct #struct_name {
129 #(#fields)*
130 }
131 };
132
133 (tokens, imports)
134}
135
/// Turns an arbitrary column name into a string usable as a Rust identifier.
///
/// Every character outside `[A-Za-z0-9_]`, including non-ASCII letters, becomes `_`.
/// A leading digit gets a `_` prefix. Some inputs would leave no usable identifier:
/// the empty string, or anything that collapses to a lone `_` (a reserved token,
/// not a valid field name). Those become the placeholder `_field`.
///
/// Rust keywords are not handled here; callers check the result with `is_rust_keyword`.
pub(crate) fn sanitize_rust_ident(name: &str) -> String {
    let mut out: String = name
        .chars()
        .map(|c| if c.is_ascii_alphanumeric() || c == '_' { c } else { '_' })
        .collect();
    if out.is_empty() || out == "_" {
        return "_field".to_string();
    }
    if out.starts_with(|c: char| c.is_ascii_digit()) {
        out.insert(0, '_');
    }
    out
}
163
164fn detect_custom_sql_type(udt_name: &str, schema_info: &SchemaInfo) -> (Option<String>, bool) {
169 let (base_name, is_array) = match udt_name.strip_prefix('_') {
170 Some(inner) => (inner, true),
171 None => (udt_name, false),
172 };
173
174 if schema_info.enums.iter().any(|e| e.name == base_name) {
176 return (Some(base_name.to_string()), is_array);
177 }
178
179 if schema_info
181 .composite_types
182 .iter()
183 .any(|c| c.name == base_name)
184 {
185 return (Some(base_name.to_string()), is_array);
186 }
187
188 let is_domain = schema_info.domains.iter().any(|d| d.name == base_name);
192 if !is_domain && !typemap::postgres::is_builtin(base_name) {
193 return (Some(base_name.to_string()), is_array);
194 }
195
196 if is_array {
200 return (Some(base_name.to_string()), true);
201 }
202
203 (None, false)
204}
205
206fn resolve_column_type(
207 col: &crate::introspect::ColumnInfo,
208 db_kind: DatabaseKind,
209 table: &TableInfo,
210 schema_info: &SchemaInfo,
211 type_overrides: &HashMap<String, String>,
212 time_crate: TimeCrate,
213) -> typemap::RustType {
214 if db_kind == DatabaseKind::Mysql && col.udt_name.starts_with("enum(") {
216 let enum_type_name = typemap::mysql::resolve_enum_type(&table.name, &col.name);
217 let rt = typemap::RustType::with_import(
218 &enum_type_name,
219 &format!("use super::types::{};", enum_type_name),
220 );
221 return if col.is_nullable {
222 rt.wrap_option()
223 } else {
224 rt
225 };
226 }
227
228 typemap::map_column(col, db_kind, schema_info, type_overrides, time_crate)
229}
230
#[cfg(test)]
mod tests {
    //! Unit tests for struct generation and identifier sanitizing. Most tests render
    //! a struct to formatted source and assert on substrings of that text.

    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    // ----- fixtures -----

    /// A `public`-schema table with the given columns.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// A non-PK, default-less `public` column whose `data_type` mirrors `udt_name`.
    fn make_col(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
            udt_schema: None,
            column_default: None,
        }
    }

    /// A MySQL inline-enum column in the `test_db` schema.
    fn make_mysql_enum_col(name: &str, column_type: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            data_type: "enum".to_string(),
            schema_name: "test_db".to_string(),
            ..make_col(name, column_type, nullable)
        }
    }

    /// Renders `table` as a non-view struct with Chrono time types; returns the
    /// formatted code and the collected imports.
    fn gen_with(
        table: &TableInfo,
        schema: &SchemaInfo,
        db: DatabaseKind,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_struct(
            table,
            db,
            schema,
            derives,
            overrides,
            false,
            TimeCrate::Chrono,
        );
        (parse_and_format(&tokens).unwrap(), imports)
    }

    /// Formatted code for `table` as a plain Postgres table.
    fn gen(table: &TableInfo) -> String {
        let (code, _) = gen_with(
            table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &[],
            &HashMap::new(),
        );
        code
    }

    /// Imports collected for `table` as a plain Postgres table.
    fn pg_imports(table: &TableInfo) -> BTreeSet<String> {
        let (_, imports) = gen_with(
            table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &[],
            &HashMap::new(),
        );
        imports
    }

    // ----- basic structure -----

    #[test]
    fn test_simple_table() {
        let code = gen(&make_table(
            "users",
            vec![
                make_col("id", "int4", false),
                make_col("name", "text", false),
            ],
        ));
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub name: String"));
    }

    // ----- struct naming -----

    #[test]
    fn test_struct_name_pascal_case_and_singular() {
        let code = gen(&make_table("user_roles", vec![make_col("id", "int4", false)]));
        assert!(
            code.contains("pub struct UserRole"),
            "expected singular PascalCase struct name, got:\n{}",
            code
        );
        assert!(!code.contains("pub struct UserRoles"));
    }

    #[test]
    fn test_struct_name_is_singular() {
        let code = gen(&make_table("users", vec![make_col("id", "int4", false)]));
        assert!(
            code.contains("pub struct User"),
            "table 'users' must produce singular 'User' struct, got:\n{}",
            code
        );
        assert!(!code.contains("pub struct Users"));
    }

    #[test]
    fn test_struct_name_already_singular_unchanged() {
        let code = gen(&make_table("agent_connector", vec![make_col("id", "int4", false)]));
        assert!(code.contains("pub struct AgentConnector"));
    }

    #[test]
    fn test_struct_name_uncountable_unchanged() {
        let code = gen(&make_table("news", vec![make_col("id", "int4", false)]));
        assert!(code.contains("pub struct News"));
    }

    // ----- nullability -----

    #[test]
    fn test_nullable_column() {
        let code = gen(&make_table("users", vec![make_col("email", "text", true)]));
        assert!(code.contains("pub email: Option<String>"));
    }

    #[test]
    fn test_non_nullable_column() {
        let code = gen(&make_table("users", vec![make_col("name", "text", false)]));
        assert!(code.contains("pub name: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_mix_nullable() {
        let code = gen(&make_table(
            "users",
            vec![make_col("id", "int4", false), make_col("bio", "text", true)],
        ));
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub bio: Option<String>"));
    }

    // ----- keyword columns -----

    #[test]
    fn test_keyword_type_renamed() {
        let code = gen(&make_table("connector", vec![make_col("type", "text", false)]));
        assert!(code.contains("pub connector_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    #[test]
    fn test_keyword_fn_renamed() {
        let code = gen(&make_table("item", vec![make_col("fn", "text", false)]));
        assert!(code.contains("pub item_fn: String"));
        assert!(code.contains("sqlx(rename = \"fn\")"));
    }

    #[test]
    fn test_keyword_match_renamed() {
        let code = gen(&make_table("game", vec![make_col("match", "text", false)]));
        assert!(code.contains("pub game_match: String"));
    }

    #[test]
    fn test_non_keyword_no_rename() {
        let code = gen(&make_table("users", vec![make_col("name", "text", false)]));
        assert!(!code.contains("sqlx(rename"));
    }

    // ----- case conversion -----

    #[test]
    fn test_camel_case_column_renamed() {
        let code = gen(&make_table("users", vec![make_col("CreatedAt", "text", false)]));
        assert!(code.contains("pub created_at: String"));
        assert!(code.contains("sqlx(rename = \"CreatedAt\")"));
    }

    #[test]
    fn test_mixed_case_column_renamed() {
        let code = gen(&make_table("users", vec![make_col("firstName", "text", false)]));
        assert!(code.contains("pub first_name: String"));
        assert!(code.contains("sqlx(rename = \"firstName\")"));
    }

    #[test]
    fn test_already_snake_case_no_rename() {
        let code = gen(&make_table("users", vec![make_col("created_at", "text", false)]));
        assert!(code.contains("pub created_at: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    // ----- derives -----

    #[test]
    fn test_default_derives() {
        let code = gen(&make_table("users", vec![make_col("id", "int4", false)]));
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("sqlx::FromRow") || code.contains("sqlx :: FromRow"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let derives = vec!["Serialize".to_string()];
        let (code, _) = gen_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &derives,
            &HashMap::new(),
        );
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derives_both_serde() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = gen_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &derives,
            &HashMap::new(),
        );
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    // ----- imports -----

    #[test]
    fn test_uuid_import() {
        let imports = pg_imports(&make_table("users", vec![make_col("id", "uuid", false)]));
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_timestamptz_import() {
        let imports = pg_imports(&make_table(
            "users",
            vec![make_col("created_at", "timestamptz", false)],
        ));
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    #[test]
    fn test_int4_only_serde_import() {
        let imports = pg_imports(&make_table("users", vec![make_col("id", "int4", false)]));
        assert_eq!(imports.len(), 2);
        assert!(imports.iter().any(|i| i.contains("serde")));
        assert!(imports.iter().any(|i| i.contains("sqlx_gen::SqlxGen")));
    }

    #[test]
    fn test_multiple_imports_collected() {
        let imports = pg_imports(&make_table(
            "users",
            vec![
                make_col("id", "uuid", false),
                make_col("created_at", "timestamptz", false),
            ],
        ));
        assert!(imports.iter().any(|i| i.contains("uuid")));
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    // ----- MySQL inline enums -----

    #[test]
    fn test_mysql_enum_column() {
        let table = make_table(
            "users",
            vec![make_mysql_enum_col("status", "enum('active','inactive')", false)],
        );
        let (code, imports) = gen_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Mysql,
            &[],
            &HashMap::new(),
        );
        assert!(code.contains("UsersStatus"));
        assert!(imports.iter().any(|i| i.contains("super::types::")));
    }

    #[test]
    fn test_mysql_enum_nullable() {
        let table = make_table("users", vec![make_mysql_enum_col("status", "enum('a','b')", true)]);
        let (code, _) = gen_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Mysql,
            &[],
            &HashMap::new(),
        );
        assert!(code.contains("Option<UsersStatus>"));
    }

    // ----- type overrides -----

    #[test]
    fn test_type_override() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let overrides: HashMap<String, String> =
            [("jsonb".to_string(), "MyJson".to_string())].into_iter().collect();
        let (code, _) = gen_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &[],
            &overrides,
        );
        assert!(code.contains("pub data: MyJson"));
    }

    #[test]
    fn test_type_override_absent() {
        let code = gen(&make_table("users", vec![make_col("data", "jsonb", false)]));
        assert!(code.contains("Value"));
    }

    #[test]
    fn test_type_override_nullable() {
        let table = make_table("users", vec![make_col("data", "jsonb", true)]);
        let overrides: HashMap<String, String> =
            [("jsonb".to_string(), "MyJson".to_string())].into_iter().collect();
        let (code, _) = gen_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &[],
            &overrides,
        );
        assert!(code.contains("Option<MyJson>"));
    }

    // ----- sql_type annotations -----

    #[test]
    fn test_native_array_text_gets_sql_type_annotation() {
        let code = gen(&make_table("posts", vec![make_col("tags", "_text", false)]));
        assert!(code.contains("Vec<String>"));
        assert!(code.contains("sql_type = \"text\""));
        assert!(code.contains("is_array"));
    }

    #[test]
    fn test_native_array_int4_gets_sql_type_annotation() {
        let code = gen(&make_table("data", vec![make_col("values", "_int4", false)]));
        assert!(code.contains("Vec<i32>"));
        assert!(code.contains("sql_type = \"int4\""));
        assert!(code.contains("is_array"));
    }

    #[test]
    fn test_native_array_nullable_gets_sql_type_annotation() {
        let code = gen(&make_table("posts", vec![make_col("tags", "_text", true)]));
        assert!(code.contains("Option<Vec<String>>"));
        assert!(code.contains("sql_type = \"text\""));
        assert!(code.contains("is_array"));
    }

    #[test]
    fn test_scalar_builtin_no_sql_type_annotation() {
        let code = gen(&make_table("users", vec![make_col("name", "text", false)]));
        assert!(code.contains("pub name: String"));
        assert!(!code.contains("sql_type"));
    }

    // ----- identifier sanitizing -----

    #[test]
    fn test_sanitize_replaces_dash() {
        assert_eq!(sanitize_rust_ident("user-id"), "user_id");
    }

    #[test]
    fn test_sanitize_replaces_space() {
        assert_eq!(sanitize_rust_ident("created at"), "created_at");
    }

    #[test]
    fn test_sanitize_replaces_dot() {
        assert_eq!(sanitize_rust_ident("a.b"), "a_b");
    }

    #[test]
    fn test_sanitize_prefixes_leading_digit() {
        assert_eq!(sanitize_rust_ident("123abc"), "_123abc");
    }

    #[test]
    fn test_sanitize_empty_becomes_placeholder() {
        assert_eq!(sanitize_rust_ident(""), "_field");
    }

    #[test]
    fn test_sanitize_leaves_valid_ident_unchanged() {
        assert_eq!(sanitize_rust_ident("user_id"), "user_id");
        assert_eq!(sanitize_rust_ident("_private"), "_private");
    }

    #[test]
    fn test_column_with_dash_generates_valid_rust() {
        let code = gen(&make_table("users", vec![make_col("user-id", "int4", false)]));
        assert!(
            code.contains("pub user_id:") || code.contains("user_id:"),
            "expected sanitized identifier, got:\n{}",
            code
        );
        assert!(code.contains("sqlx(rename = \"user-id\")"));
    }
}