1use std::collections::{BTreeSet, HashMap};
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::{DatabaseKind, TimeCrate};
8use crate::codegen::{imports_for_derives, is_rust_keyword};
9use crate::introspect::{SchemaInfo, TableInfo};
10use crate::typemap;
11
12pub fn generate_struct(
13 table: &TableInfo,
14 db_kind: DatabaseKind,
15 schema_info: &SchemaInfo,
16 extra_derives: &[String],
17 type_overrides: &HashMap<String, String>,
18 is_view: bool,
19 time_crate: TimeCrate,
20) -> (TokenStream, BTreeSet<String>) {
21 let mut imports = BTreeSet::new();
22 for imp in imports_for_derives(extra_derives) {
23 imports.insert(imp);
24 }
25 let struct_name = format_ident!("{}", table.name.to_upper_camel_case());
26
27 imports.insert("use serde::{Serialize, Deserialize};".to_string());
29 imports.insert("use sqlx_gen::SqlxGen;".to_string());
30 let mut derive_tokens = vec![
31 quote! { Debug },
32 quote! { Clone },
33 quote! { PartialEq },
34 quote! { Eq },
35 quote! { Serialize },
36 quote! { Deserialize },
37 quote! { sqlx::FromRow },
38 quote! { SqlxGen },
39 ];
40 for d in extra_derives {
41 let ident = format_ident!("{}", d);
42 derive_tokens.push(quote! { #ident });
43 }
44
45 let fields: Vec<TokenStream> = table
47 .columns
48 .iter()
49 .map(|col| {
50 let rust_type =
51 resolve_column_type(col, db_kind, table, schema_info, type_overrides, time_crate);
52 if let Some(imp) = &rust_type.needs_import {
53 imports.insert(imp.clone());
54 }
55
56 let field_name_snake = sanitize_rust_ident(&col.name.to_snake_case());
57 let (effective_name, needs_rename) = if is_rust_keyword(&field_name_snake) {
60 let prefixed = format!("{}_{}", table.name.to_snake_case(), field_name_snake);
61 (prefixed, true)
62 } else {
63 let changed = field_name_snake != col.name;
64 (field_name_snake, changed)
65 };
66
67 let field_ident = format_ident!("{}", effective_name);
68 let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
69 let fallback = format_ident!("String");
70 quote! { #fallback }
71 });
72
73 let rename = if needs_rename {
74 let original = &col.name;
75 quote! { #[sqlx(rename = #original)] }
76 } else {
77 quote! {}
78 };
79
80 let (sql_type, is_sql_array) = detect_custom_sql_type(&col.udt_name, schema_info);
82 let has_pk = col.is_primary_key;
83 let has_sql_type = sql_type.is_some();
84 let has_default = col.column_default.is_some();
85
86 let sqlx_gen_attr = if has_pk || has_sql_type || has_default {
87 let pk_part = if has_pk {
88 quote! { primary_key, }
89 } else {
90 quote! {}
91 };
92 let sql_type_part = match &sql_type {
93 Some(t) => quote! { sql_type = #t, },
94 None => quote! {},
95 };
96 let array_part = if is_sql_array {
97 quote! { is_array, }
98 } else {
99 quote! {}
100 };
101 let default_part = match &col.column_default {
102 Some(d) => quote! { column_default = #d, },
103 None => quote! {},
104 };
105 quote! { #[sqlx_gen(#pk_part #sql_type_part #array_part #default_part)] }
106 } else {
107 quote! {}
108 };
109
110 quote! {
111 #sqlx_gen_attr
112 #rename
113 pub #field_ident: #type_tokens,
114 }
115 })
116 .collect();
117
118 let table_name_str = &table.name;
119 let schema_name_str = &table.schema_name;
120 let kind_str = if is_view { "view" } else { "table" };
121
122 let tokens = quote! {
123 #[derive(#(#derive_tokens),*)]
124 #[sqlx_gen(kind = #kind_str, schema = #schema_name_str, table = #table_name_str)]
125 pub struct #struct_name {
126 #(#fields)*
127 }
128 };
129
130 (tokens, imports)
131}
132
/// Turns an arbitrary column/table name into a string usable as a Rust
/// identifier.
///
/// Every character that is not an ASCII letter, digit or `_` becomes `_`, and
/// a result starting with a digit is prefixed with `_`. Inputs that would
/// otherwise yield no usable name map to the placeholder `_field`. Those are
/// the empty string and anything that collapses to a lone `_`, which is the
/// wildcard token and cannot name a struct field.
///
/// Rust keywords are not handled here; the result may still be a keyword.
pub(crate) fn sanitize_rust_ident(name: &str) -> String {
    let mut out: String = name
        .chars()
        .map(|c| {
            if c.is_ascii_alphanumeric() || c == '_' {
                c
            } else {
                '_'
            }
        })
        .collect();
    if out.is_empty() || out == "_" {
        return "_field".to_string();
    }
    if out.starts_with(|c: char| c.is_ascii_digit()) {
        out.insert(0, '_');
    }
    out
}
160
161fn detect_custom_sql_type(udt_name: &str, schema_info: &SchemaInfo) -> (Option<String>, bool) {
166 let (base_name, is_array) = match udt_name.strip_prefix('_') {
167 Some(inner) => (inner, true),
168 None => (udt_name, false),
169 };
170
171 if schema_info.enums.iter().any(|e| e.name == base_name) {
173 return (Some(base_name.to_string()), is_array);
174 }
175
176 if schema_info
178 .composite_types
179 .iter()
180 .any(|c| c.name == base_name)
181 {
182 return (Some(base_name.to_string()), is_array);
183 }
184
185 let is_domain = schema_info.domains.iter().any(|d| d.name == base_name);
189 if !is_domain && !typemap::postgres::is_builtin(base_name) {
190 return (Some(base_name.to_string()), is_array);
191 }
192
193 if is_array {
197 return (Some(base_name.to_string()), true);
198 }
199
200 (None, false)
201}
202
203fn resolve_column_type(
204 col: &crate::introspect::ColumnInfo,
205 db_kind: DatabaseKind,
206 table: &TableInfo,
207 schema_info: &SchemaInfo,
208 type_overrides: &HashMap<String, String>,
209 time_crate: TimeCrate,
210) -> typemap::RustType {
211 if db_kind == DatabaseKind::Mysql && col.udt_name.starts_with("enum(") {
213 let enum_type_name = typemap::mysql::resolve_enum_type(&table.name, &col.name);
214 let rt = typemap::RustType::with_import(
215 &enum_type_name,
216 &format!("use super::types::{};", enum_type_name),
217 );
218 return if col.is_nullable {
219 rt.wrap_option()
220 } else {
221 rt
222 };
223 }
224
225 typemap::map_column(col, db_kind, schema_info, type_overrides, time_crate)
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;
    use crate::introspect::ColumnInfo;

    /// Builds a `public`-schema table with the given columns.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// Builds a non-key column without a default; `data_type` and `udt_name`
    /// are both set to `udt_name`.
    fn make_col(name: &str, udt_name: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: nullable,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
            udt_schema: None,
            column_default: None,
        }
    }

    /// Builds a MySQL inline-enum column whose `udt_name` is the full
    /// `enum(...)` column type.
    fn make_mysql_enum_col(name: &str, column_type: &str, nullable: bool) -> ColumnInfo {
        ColumnInfo {
            data_type: "enum".to_string(),
            schema_name: "test_db".to_string(),
            ..make_col(name, column_type, nullable)
        }
    }

    /// Generates (as a table, with chrono time types) and formats a struct,
    /// also returning the collected imports.
    fn gen_with(
        table: &TableInfo,
        schema: &SchemaInfo,
        db: DatabaseKind,
        derives: &[String],
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let (tokens, imports) = generate_struct(
            table,
            db,
            schema,
            derives,
            overrides,
            false,
            TimeCrate::Chrono,
        );
        (parse_and_format(&tokens).unwrap(), imports)
    }

    /// Generates and formats a Postgres struct with default options.
    fn gen(table: &TableInfo) -> String {
        gen_with(
            table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &[],
            &HashMap::new(),
        )
        .0
    }

    /// Returns only the imports of a default-option Postgres struct.
    fn postgres_imports(table: &TableInfo) -> BTreeSet<String> {
        gen_with(
            table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &[],
            &HashMap::new(),
        )
        .1
    }

    // --- basic mapping ---

    #[test]
    fn test_simple_table() {
        let table = make_table(
            "users",
            vec![
                make_col("id", "int4", false),
                make_col("name", "text", false),
            ],
        );
        let code = gen(&table);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub name: String"));
    }

    #[test]
    fn test_struct_name_pascal_case() {
        let table = make_table("user_roles", vec![make_col("id", "int4", false)]);
        let code = gen(&table);
        assert!(code.contains("pub struct UserRoles"));
    }

    #[test]
    fn test_struct_name_simple() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = gen(&table);
        assert!(code.contains("pub struct Users"));
    }

    // --- nullability ---

    #[test]
    fn test_nullable_column() {
        let table = make_table("users", vec![make_col("email", "text", true)]);
        let code = gen(&table);
        assert!(code.contains("pub email: Option<String>"));
    }

    #[test]
    fn test_non_nullable_column() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub name: String"));
        assert!(!code.contains("Option"));
    }

    #[test]
    fn test_mix_nullable() {
        let table = make_table(
            "users",
            vec![make_col("id", "int4", false), make_col("bio", "text", true)],
        );
        let code = gen(&table);
        assert!(code.contains("pub id: i32"));
        assert!(code.contains("pub bio: Option<String>"));
    }

    // --- keyword and case renames ---

    #[test]
    fn test_keyword_type_renamed() {
        let table = make_table("connector", vec![make_col("type", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub connector_type: String"));
        assert!(code.contains("sqlx(rename = \"type\")"));
    }

    #[test]
    fn test_keyword_fn_renamed() {
        let table = make_table("item", vec![make_col("fn", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub item_fn: String"));
        assert!(code.contains("sqlx(rename = \"fn\")"));
    }

    #[test]
    fn test_keyword_match_renamed() {
        let table = make_table("game", vec![make_col("match", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub game_match: String"));
    }

    #[test]
    fn test_non_keyword_no_rename() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = gen(&table);
        assert!(!code.contains("sqlx(rename"));
    }

    #[test]
    fn test_camel_case_column_renamed() {
        let table = make_table("users", vec![make_col("CreatedAt", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub created_at: String"));
        assert!(code.contains("sqlx(rename = \"CreatedAt\")"));
    }

    #[test]
    fn test_mixed_case_column_renamed() {
        let table = make_table("users", vec![make_col("firstName", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub first_name: String"));
        assert!(code.contains("sqlx(rename = \"firstName\")"));
    }

    #[test]
    fn test_already_snake_case_no_rename() {
        let table = make_table("users", vec![make_col("created_at", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub created_at: String"));
        assert!(!code.contains("sqlx(rename"));
    }

    // --- derives ---

    #[test]
    fn test_default_derives() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let code = gen(&table);
        assert!(code.contains("Debug"));
        assert!(code.contains("Clone"));
        assert!(code.contains("sqlx::FromRow") || code.contains("sqlx :: FromRow"));
    }

    #[test]
    fn test_extra_derive_serialize() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let derives = vec!["Serialize".to_string()];
        let (code, _) = gen_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &derives,
            &HashMap::new(),
        );
        assert!(code.contains("Serialize"));
    }

    #[test]
    fn test_extra_derives_both_serde() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let (_, imports) = gen_with(
            &table,
            &SchemaInfo::default(),
            DatabaseKind::Postgres,
            &derives,
            &HashMap::new(),
        );
        assert!(imports.iter().any(|i| i.contains("serde")));
    }

    // --- imports ---

    #[test]
    fn test_uuid_import() {
        let table = make_table("users", vec![make_col("id", "uuid", false)]);
        let imports = postgres_imports(&table);
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_timestamptz_import() {
        let table = make_table("users", vec![make_col("created_at", "timestamptz", false)]);
        let imports = postgres_imports(&table);
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    #[test]
    fn test_int4_only_serde_import() {
        let table = make_table("users", vec![make_col("id", "int4", false)]);
        let imports = postgres_imports(&table);
        assert_eq!(imports.len(), 2);
        assert!(imports.iter().any(|i| i.contains("serde")));
        assert!(imports.iter().any(|i| i.contains("sqlx_gen::SqlxGen")));
    }

    #[test]
    fn test_multiple_imports_collected() {
        let table = make_table(
            "users",
            vec![
                make_col("id", "uuid", false),
                make_col("created_at", "timestamptz", false),
            ],
        );
        let imports = postgres_imports(&table);
        assert!(imports.iter().any(|i| i.contains("uuid")));
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    // --- MySQL enums ---

    #[test]
    fn test_mysql_enum_column() {
        let table = make_table(
            "users",
            vec![make_mysql_enum_col("status", "enum('active','inactive')", false)],
        );
        let schema = SchemaInfo::default();
        let (code, imports) = gen_with(&table, &schema, DatabaseKind::Mysql, &[], &HashMap::new());
        assert!(code.contains("UsersStatus"));
        assert!(imports.iter().any(|i| i.contains("super::types::")));
    }

    #[test]
    fn test_mysql_enum_nullable() {
        let table = make_table(
            "users",
            vec![make_mysql_enum_col("status", "enum('a','b')", true)],
        );
        let schema = SchemaInfo::default();
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Mysql, &[], &HashMap::new());
        assert!(code.contains("Option<UsersStatus>"));
    }

    // --- type overrides ---

    #[test]
    fn test_type_override() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let schema = SchemaInfo::default();
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &overrides);
        assert!(code.contains("pub data: MyJson"));
    }

    #[test]
    fn test_type_override_absent() {
        let table = make_table("users", vec![make_col("data", "jsonb", false)]);
        let code = gen(&table);
        assert!(code.contains("Value"));
    }

    #[test]
    fn test_type_override_nullable() {
        let table = make_table("users", vec![make_col("data", "jsonb", true)]);
        let schema = SchemaInfo::default();
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = gen_with(&table, &schema, DatabaseKind::Postgres, &[], &overrides);
        assert!(code.contains("Option<MyJson>"));
    }

    // --- sqlx_gen field attributes ---

    #[test]
    fn test_native_array_text_gets_sql_type_annotation() {
        let table = make_table("posts", vec![make_col("tags", "_text", false)]);
        let code = gen(&table);
        assert!(code.contains("Vec<String>"));
        assert!(code.contains("sql_type = \"text\""));
        assert!(code.contains("is_array"));
    }

    #[test]
    fn test_native_array_int4_gets_sql_type_annotation() {
        let table = make_table("data", vec![make_col("values", "_int4", false)]);
        let code = gen(&table);
        assert!(code.contains("Vec<i32>"));
        assert!(code.contains("sql_type = \"int4\""));
        assert!(code.contains("is_array"));
    }

    #[test]
    fn test_native_array_nullable_gets_sql_type_annotation() {
        let table = make_table("posts", vec![make_col("tags", "_text", true)]);
        let code = gen(&table);
        assert!(code.contains("Option<Vec<String>>"));
        assert!(code.contains("sql_type = \"text\""));
        assert!(code.contains("is_array"));
    }

    #[test]
    fn test_scalar_builtin_no_sql_type_annotation() {
        let table = make_table("users", vec![make_col("name", "text", false)]);
        let code = gen(&table);
        assert!(code.contains("pub name: String"));
        assert!(!code.contains("sql_type"));
    }

    #[test]
    fn test_primary_key_attribute() {
        let mut id = make_col("id", "int4", false);
        id.is_primary_key = true;
        let code = gen(&make_table("users", vec![id]));
        assert!(code.contains("primary_key"));
    }

    #[test]
    fn test_plain_column_has_no_primary_key_attribute() {
        let code = gen(&make_table("users", vec![make_col("id", "int4", false)]));
        assert!(!code.contains("primary_key"));
    }

    #[test]
    fn test_column_default_attribute() {
        let mut created = make_col("created_at", "timestamptz", false);
        created.column_default = Some("now()".to_string());
        let code = gen(&make_table("users", vec![created]));
        assert!(code.contains("column_default = \"now()\""));
    }

    // --- struct-level sqlx_gen attribute ---

    #[test]
    fn test_table_kind_schema_and_name_attribute() {
        let code = gen(&make_table("users", vec![make_col("id", "int4", false)]));
        assert!(code.contains("kind = \"table\""));
        assert!(code.contains("schema = \"public\""));
        assert!(code.contains("table = \"users\""));
    }

    #[test]
    fn test_view_kind_attribute() {
        let table = make_table("active_users", vec![make_col("id", "int4", false)]);
        let (tokens, _) = generate_struct(
            &table,
            DatabaseKind::Postgres,
            &SchemaInfo::default(),
            &[],
            &HashMap::new(),
            true,
            TimeCrate::Chrono,
        );
        let code = parse_and_format(&tokens).unwrap();
        assert!(code.contains("kind = \"view\""));
        assert!(!code.contains("kind = \"table\""));
    }

    // --- identifier sanitization ---

    #[test]
    fn test_sanitize_replaces_dash() {
        assert_eq!(sanitize_rust_ident("user-id"), "user_id");
    }

    #[test]
    fn test_sanitize_replaces_space() {
        assert_eq!(sanitize_rust_ident("created at"), "created_at");
    }

    #[test]
    fn test_sanitize_replaces_dot() {
        assert_eq!(sanitize_rust_ident("a.b"), "a_b");
    }

    #[test]
    fn test_sanitize_prefixes_leading_digit() {
        assert_eq!(sanitize_rust_ident("123abc"), "_123abc");
    }

    #[test]
    fn test_sanitize_empty_becomes_placeholder() {
        assert_eq!(sanitize_rust_ident(""), "_field");
    }

    #[test]
    fn test_sanitize_leaves_valid_ident_unchanged() {
        assert_eq!(sanitize_rust_ident("user_id"), "user_id");
        assert_eq!(sanitize_rust_ident("_private"), "_private");
    }

    #[test]
    fn test_column_with_dash_generates_valid_rust() {
        let table = make_table("users", vec![make_col("user-id", "int4", false)]);
        let code = gen(&table);
        assert!(
            code.contains("pub user_id: i32"),
            "expected sanitized identifier, got:\n{}",
            code
        );
        assert!(code.contains("sqlx(rename = \"user-id\")"));
    }
}