1pub mod composite_gen;
2pub mod crud_gen;
3pub mod domain_gen;
4pub mod entity_parser;
5pub mod enum_gen;
6pub mod struct_gen;
7
8use std::collections::{BTreeSet, HashMap};
9
10use proc_macro2::TokenStream;
11
12use crate::cli::DatabaseKind;
13use crate::introspect::SchemaInfo;
14
/// Words that cannot be used verbatim as Rust identifiers: the strict
/// keywords followed by the reserved-for-future-use ones.
const RUST_KEYWORDS: &[&str] = &[
    "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum",
    "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move",
    "mut", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true",
    "type", "unsafe", "use", "where", "while", "yield", "abstract", "become", "box", "do",
    "final", "macro", "override", "priv", "try", "typeof", "unsized", "virtual",
];

/// Returns `true` when `name` exactly matches (case-sensitively) one of the
/// entries in `RUST_KEYWORDS`.
pub fn is_rust_keyword(name: &str) -> bool {
    RUST_KEYWORDS.iter().any(|&keyword| keyword == name)
}
28
/// Builds the `use` lines required by the user-supplied extra derives.
///
/// Only serde derives need an import: when `Serialize` and/or `Deserialize`
/// are requested, a single `use serde::{...};` line naming them (in that
/// order) is returned. Any other derive contributes nothing.
pub fn imports_for_derives(extra_derives: &[String]) -> Vec<String> {
    let serde_traits: Vec<&str> = ["Serialize", "Deserialize"]
        .iter()
        .copied()
        .filter(|wanted| extra_derives.iter().any(|d| d.as_str() == *wanted))
        .collect();

    if serde_traits.is_empty() {
        Vec::new()
    } else {
        vec![format!("use serde::{{{}}};", serde_traits.join(", "))]
    }
}
45
/// Collapses every run of consecutive underscores in `name` into a single
/// underscore; all other characters are kept unchanged.
///
/// `"user__roles"` becomes `"user_roles"`, `"__x"` becomes `"_x"`.
pub fn normalize_module_name(name: &str) -> String {
    let mut normalized = String::with_capacity(name.len());
    for ch in name.chars() {
        // An underscore is only dropped when the output already ends in one,
        // i.e. when it continues a run.
        if ch == '_' && normalized.ends_with('_') {
            continue;
        }
        normalized.push(ch);
    }
    normalized
}

/// Schema names treated as the database's default schema; objects in them are
/// never given a schema prefix in their module name.
const DEFAULT_SCHEMAS: &[&str] = &["public", "main", "dbo"];

/// Returns `true` if `schema` is one of `DEFAULT_SCHEMAS` (case-sensitive).
pub fn is_default_schema(schema: &str) -> bool {
    DEFAULT_SCHEMAS.iter().any(|&default| default == schema)
}

/// Computes the module (file stem) name for a table or view.
///
/// When the object's name collides with an object in another schema and its
/// own schema is not a default one, the name becomes `{schema}_{table}`;
/// otherwise it is just the table name. Either way, runs of underscores are
/// collapsed with [`normalize_module_name`].
pub fn build_module_name(schema_name: &str, table_name: &str, name_collides: bool) -> String {
    let needs_schema_prefix = name_collides && !is_default_schema(schema_name);
    let raw_name = if needs_schema_prefix {
        format!("{}_{}", schema_name, table_name)
    } else {
        table_name.to_string()
    };
    normalize_module_name(&raw_name)
}
82
83fn find_colliding_names(schema_info: &SchemaInfo) -> BTreeSet<&str> {
85 let mut seen: HashMap<&str, BTreeSet<&str>> = HashMap::new();
86 for t in &schema_info.tables {
87 seen.entry(t.name.as_str()).or_default().insert(t.schema_name.as_str());
88 }
89 for v in &schema_info.views {
90 seen.entry(v.name.as_str()).or_default().insert(v.schema_name.as_str());
91 }
92 seen.into_iter()
93 .filter(|(_, schemas)| schemas.len() > 1)
94 .map(|(name, _)| name)
95 .collect()
96}
97
/// One output file produced by [`generate`].
#[derive(Debug, Clone)]
pub struct GeneratedFile {
    /// Bare file name, e.g. `users.rs` for a table/view or `types.rs` for the
    /// shared enum/composite/domain file.
    pub filename: String,
    /// Description of the source object, e.g. `Table: public.users` or
    /// `View: public.active_users`; `None` for `types.rs`.
    pub origin: Option<String>,
    /// Formatted Rust source of the file, including any `use` header.
    pub code: String,
}
106
107pub fn generate(
109 schema_info: &SchemaInfo,
110 db_kind: DatabaseKind,
111 extra_derives: &[String],
112 type_overrides: &HashMap<String, String>,
113 single_file: bool,
114) -> Vec<GeneratedFile> {
115 let mut files = Vec::new();
116
117 let colliding_names = find_colliding_names(schema_info);
119
120 for table in &schema_info.tables {
122 let (tokens, imports) =
123 struct_gen::generate_struct(table, db_kind, schema_info, extra_derives, type_overrides, false);
124 let imports = filter_imports(&imports, single_file);
125 let code = format_tokens_with_imports(&tokens, &imports);
126 let module_name = build_module_name(&table.schema_name, &table.name, colliding_names.contains(table.name.as_str()));
127 let origin = format!("Table: {}.{}", table.schema_name, table.name);
128 files.push(GeneratedFile {
129 filename: format!("{}.rs", module_name),
130 origin: Some(origin),
131 code,
132 });
133 }
134
135 for view in &schema_info.views {
137 let (tokens, imports) =
138 struct_gen::generate_struct(view, db_kind, schema_info, extra_derives, type_overrides, true);
139 let imports = filter_imports(&imports, single_file);
140 let code = format_tokens_with_imports(&tokens, &imports);
141 let module_name = build_module_name(&view.schema_name, &view.name, colliding_names.contains(view.name.as_str()));
142 let origin = format!("View: {}.{}", view.schema_name, view.name);
143 files.push(GeneratedFile {
144 filename: format!("{}.rs", module_name),
145 origin: Some(origin),
146 code,
147 });
148 }
149
150 let mut types_blocks: Vec<String> = Vec::new();
153 let mut types_imports = BTreeSet::new();
154
155 let enum_defaults = extract_enum_defaults(schema_info);
157 for enum_info in &schema_info.enums {
158 let mut enriched = enum_info.clone();
159 if enriched.default_variant.is_none() {
160 if let Some(default) = enum_defaults.get(&enum_info.name) {
161 enriched.default_variant = Some(default.clone());
162 }
163 }
164 let (tokens, imports) = enum_gen::generate_enum(&enriched, db_kind, extra_derives);
165 types_blocks.push(format_tokens(&tokens));
166 types_imports.extend(imports);
167 }
168
169 for composite in &schema_info.composite_types {
170 let (tokens, imports) = composite_gen::generate_composite(
171 composite,
172 db_kind,
173 schema_info,
174 extra_derives,
175 type_overrides,
176 );
177 types_blocks.push(format_tokens(&tokens));
178 types_imports.extend(imports);
179 }
180
181 for domain in &schema_info.domains {
182 let (tokens, imports) =
183 domain_gen::generate_domain(domain, db_kind, schema_info, type_overrides);
184 types_blocks.push(format_tokens(&tokens));
185 types_imports.extend(imports);
186 }
187
188 if !types_blocks.is_empty() {
189 let import_lines: String = types_imports
190 .iter()
191 .map(|i| format!("{}\n", i))
192 .collect();
193 let body = types_blocks.join("\n");
194 let code = if import_lines.is_empty() {
195 body
196 } else {
197 format!("{}\n\n{}", import_lines.trim_end(), body)
198 };
199 files.push(GeneratedFile {
200 filename: "types.rs".to_string(),
201 origin: None,
202 code,
203 });
204 }
205
206 files
207}
208
209fn extract_enum_defaults(schema_info: &SchemaInfo) -> HashMap<String, String> {
212 let mut defaults: HashMap<String, String> = HashMap::new();
213
214 let all_columns = schema_info
215 .tables
216 .iter()
217 .chain(schema_info.views.iter())
218 .flat_map(|t| t.columns.iter());
219
220 for col in all_columns {
221 let default_expr = match &col.column_default {
222 Some(d) => d,
223 None => continue,
224 };
225
226 let base_udt = col.udt_name.strip_prefix('_').unwrap_or(&col.udt_name);
228
229 let enum_match = schema_info.enums.iter().find(|e| e.name == base_udt);
231 if enum_match.is_none() {
232 continue;
233 }
234
235 if let Some(variant) = parse_pg_enum_default(default_expr) {
237 defaults.entry(base_udt.to_string()).or_insert(variant);
238 }
239 }
240
241 defaults
242}
243
/// Extracts the label from a PostgreSQL enum default of the form
/// `'label'::type_name`. The cast target may be schema-qualified, e.g.
/// `'label'::public.type_name`.
///
/// Surrounding whitespace is ignored. A doubled single quote inside the
/// literal (`'it''s'::mood`) is the SQL escape for one quote and is decoded.
///
/// Returns `None` for anything that is not a single-quoted literal
/// immediately followed by a `::` cast. Examples are `nextval('seq')`, an
/// uncast `'text'`, an unterminated literal, or an empty string.
fn parse_pg_enum_default(default_expr: &str) -> Option<String> {
    let literal = default_expr.trim().strip_prefix('\'')?;

    let mut label = String::new();
    let mut chars = literal.char_indices();
    while let Some((idx, ch)) = chars.next() {
        if ch != '\'' {
            label.push(ch);
            continue;
        }
        // `'` is one byte, so `idx + 1` is a valid char boundary.
        let after = &literal[idx + 1..];
        if after.starts_with('\'') {
            // `''` encodes a literal quote; consume the second one too.
            label.push('\'');
            chars.next();
            continue;
        }
        // Closing quote: only an explicit `::` cast marks a typed enum literal.
        return if after.starts_with("::") { Some(label) } else { None };
    }

    // The literal was never closed.
    None
}
261
/// Returns the imports to emit for a per-table/view file.
///
/// In single-file mode the shared types are defined in the same output, so
/// any import that goes through `super::types::` is dropped. Otherwise every
/// import is kept.
fn filter_imports(imports: &BTreeSet<String>, single_file: bool) -> BTreeSet<String> {
    imports
        .iter()
        .filter(|import| !(single_file && import.contains("super::types::")))
        .cloned()
        .collect()
}
274
275pub(crate) fn parse_and_format(tokens: &TokenStream) -> String {
277 let file = syn::parse2::<syn::File>(tokens.clone()).unwrap_or_else(|e| {
278 log::error!("Failed to parse generated code: {}", e);
279 log::error!("This is a bug in sqlx-gen. Raw tokens:\n {}", tokens);
280 std::process::exit(1);
281 });
282 let raw = prettyplease::unparse(&file);
283 add_blank_lines_between_items(&raw)
284}
285
/// Formats `tokens` as Rust source without any import header; it simply
/// delegates to `parse_and_format`. In this module it formats the individual
/// `types.rs` blocks, whose imports are collected and emitted separately.
pub(crate) fn format_tokens(tokens: &TokenStream) -> String {
    parse_and_format(tokens)
}
290
291pub fn format_tokens_with_imports(tokens: &TokenStream, imports: &BTreeSet<String>) -> String {
292 let formatted = parse_and_format(tokens);
293
294 let used_imports: Vec<&String> = imports
295 .iter()
296 .filter(|imp| is_import_used(imp, &formatted))
297 .collect();
298
299 if used_imports.is_empty() {
300 formatted
301 } else {
302 let import_lines: String = used_imports
303 .iter()
304 .map(|i| format!("{}\n", i))
305 .collect();
306 format!("{}\n\n{}", import_lines.trim_end(), formatted)
307 }
308}
309
/// Heuristically decides whether the single `use ...;` line `import` is
/// referenced by `code`, so unused imports can be left out of generated files.
///
/// * Glob imports (`...::*`) are always considered used.
/// * Grouped imports (`use a::{B, C};`) are used if any listed item is.
/// * Anything else is used if its final path segment appears in `code`.
///
/// Names are matched as whole identifiers, so `NaiveDate` is not considered
/// used merely because `NaiveDateTime` occurs. For `X as Y` the alias `Y` is
/// what gets searched for. Nested groups (`use a::{b::{C, D}}`) are not parsed
/// precisely.
fn is_import_used(import: &str, code: &str) -> bool {
    let trimmed = import.trim().trim_end_matches(';');
    let path = trimmed.strip_prefix("use ").unwrap_or(trimmed);

    if path.ends_with("::*") {
        return true;
    }

    if let Some(open) = path.find('{') {
        // Search for the closing brace only after the opening one, so
        // malformed input cannot produce a reversed (panicking) slice.
        let inner = &path[open + 1..];
        if let Some(close) = inner.find('}') {
            return inner[..close]
                .split(',')
                .map(str::trim)
                .filter(|item| !item.is_empty())
                .any(|item| is_name_used(code, item));
        }
    }

    // `rsplit` always yields at least one segment.
    let last_segment = path.rsplit("::").next().unwrap_or(path);
    is_name_used(code, last_segment)
}

/// Returns `true` if the name bound by import item `item` occurs in `code` as
/// a whole identifier. The bound name is the alias for `X as Y`, otherwise
/// the last `::` segment.
fn is_name_used(code: &str, item: &str) -> bool {
    let name = match item.rsplit_once(" as ") {
        Some((_, alias)) => alias.trim(),
        None => item.rsplit("::").next().unwrap_or(item).trim(),
    };

    // `use Trait as _;` only brings methods into scope, which cannot be
    // detected textually, so keep it. An empty name matches trivially, just
    // as `str::contains("")` did.
    if name.is_empty() || name == "_" {
        return true;
    }

    let is_ident_char = |c: char| c.is_alphanumeric() || c == '_';
    code.match_indices(name).any(|(start, matched)| {
        let before = code[..start].chars().next_back();
        let after = code[start + matched.len()..].chars().next();
        !before.map_or(false, is_ident_char) && !after.map_or(false, is_ident_char)
    })
}
342
/// Re-inserts blank lines into `prettyplease` output to make generated code
/// easier to read.
///
/// Each line is compared (trimmed) with the line before it using the rules in
/// `needs_blank_line_between`. At most one blank line is inserted before any
/// line, even when several rules match it. Line endings are normalised to
/// `\n`, and a trailing newline in `code` is not preserved.
fn add_blank_lines_between_items(code: &str) -> String {
    let lines: Vec<&str> = code.lines().collect();
    let mut result: Vec<&str> = Vec::with_capacity(lines.len());

    for (i, &line) in lines.iter().enumerate() {
        if i > 0 && needs_blank_line_between(lines[i - 1].trim(), line.trim()) {
            result.push("");
        }
        result.push(line);
    }

    result.join("\n")
}

/// Decides whether a blank line belongs between the trimmed lines `prev` and
/// `current`.
fn needs_blank_line_between(prev: &str, current: &str) -> bool {
    // A `#[sqlx(rename ...)]` attribute following a line that ends in a comma
    // (the previous enum variant).
    let renamed_variant = current.starts_with("#[sqlx(rename") && prev.ends_with(',');

    // A new struct / impl / derive-annotated item / fn right after a closing brace.
    let item_after_close = prev == "}"
        && (current.starts_with("pub struct")
            || current.starts_with("impl ")
            || current.starts_with("#[derive")
            || current.starts_with("pub async fn")
            || current.starts_with("pub fn"));

    // A `let` or `Ok(` right after an awaited call or an `unwrap_or(` statement.
    let prev_is_await_end = prev.ends_with(".await?;")
        || prev.ends_with(".await?")
        || (prev.ends_with(';') && prev.contains(".unwrap_or("));
    let after_await =
        prev_is_await_end && (current.starts_with("let ") || current.starts_with("Ok("));

    // A `let` mentioning `sqlx::` right after a `let` that does not.
    let query_after_plain_let = current.starts_with("let ")
        && current.contains("sqlx::")
        && prev.starts_with("let ")
        && !prev.contains("sqlx::");

    renamed_variant || item_after_close || after_await || query_after_plain_let
}
404
#[cfg(test)]
mod tests {
    use super::*;
    use crate::introspect::{
        ColumnInfo, CompositeTypeInfo, DomainInfo, EnumInfo, SchemaInfo, TableInfo,
    };
    use std::collections::HashMap;

    // ========== is_rust_keyword ==========

    #[test]
    fn test_keyword_type() {
        assert!(is_rust_keyword("type"));
    }

    #[test]
    fn test_keyword_fn() {
        assert!(is_rust_keyword("fn"));
    }

    #[test]
    fn test_keyword_let() {
        assert!(is_rust_keyword("let"));
    }

    #[test]
    fn test_keyword_match() {
        assert!(is_rust_keyword("match"));
    }

    #[test]
    fn test_keyword_async() {
        assert!(is_rust_keyword("async"));
    }

    #[test]
    fn test_keyword_await() {
        assert!(is_rust_keyword("await"));
    }

    #[test]
    fn test_keyword_yield() {
        assert!(is_rust_keyword("yield"));
    }

    #[test]
    fn test_keyword_abstract() {
        assert!(is_rust_keyword("abstract"));
    }

    #[test]
    fn test_keyword_try() {
        assert!(is_rust_keyword("try"));
    }

    #[test]
    fn test_not_keyword_name() {
        assert!(!is_rust_keyword("name"));
    }

    #[test]
    fn test_not_keyword_id() {
        assert!(!is_rust_keyword("id"));
    }

    #[test]
    fn test_not_keyword_uppercase_type() {
        assert!(!is_rust_keyword("Type"));
    }

    // ========== normalize_module_name ==========

    #[test]
    fn test_normalize_no_underscores() {
        assert_eq!(normalize_module_name("users"), "users");
    }

    #[test]
    fn test_normalize_single_underscore() {
        assert_eq!(normalize_module_name("user_roles"), "user_roles");
    }

    #[test]
    fn test_normalize_double_underscore() {
        assert_eq!(normalize_module_name("user__roles"), "user_roles");
    }

    #[test]
    fn test_normalize_triple_underscore() {
        assert_eq!(normalize_module_name("a___b"), "a_b");
    }

    #[test]
    fn test_normalize_leading_underscore() {
        assert_eq!(normalize_module_name("_private"), "_private");
    }

    #[test]
    fn test_normalize_trailing_underscore() {
        assert_eq!(normalize_module_name("name_"), "name_");
    }

    #[test]
    fn test_normalize_double_leading() {
        assert_eq!(normalize_module_name("__double_leading"), "_double_leading");
    }

    #[test]
    fn test_normalize_multiple_groups() {
        assert_eq!(normalize_module_name("a__b__c"), "a_b_c");
    }

    // ========== build_module_name ==========

    #[test]
    fn test_build_no_collision_no_prefix() {
        assert_eq!(build_module_name("public", "users", false), "users");
    }

    #[test]
    fn test_build_no_collision_non_default_no_prefix() {
        assert_eq!(build_module_name("billing", "invoices", false), "invoices");
    }

    #[test]
    fn test_build_collision_prefixed() {
        assert_eq!(build_module_name("billing", "users", true), "billing_users");
    }

    #[test]
    fn test_build_collision_default_schema_no_prefix() {
        assert_eq!(build_module_name("public", "users", true), "users");
    }

    #[test]
    fn test_build_collision_normalizes_double_underscore() {
        assert_eq!(
            build_module_name("billing", "agent__connector", true),
            "billing_agent_connector"
        );
    }

    // ========== is_default_schema ==========

    #[test]
    fn test_default_schema_public() {
        assert!(is_default_schema("public"));
    }

    #[test]
    fn test_default_schema_main() {
        assert!(is_default_schema("main"));
    }

    #[test]
    fn test_default_schema_dbo() {
        assert!(is_default_schema("dbo"));
    }

    #[test]
    fn test_non_default_schema() {
        assert!(!is_default_schema("billing"));
    }

    // ========== imports_for_derives ==========

    #[test]
    fn test_imports_empty() {
        let result = imports_for_derives(&[]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_imports_serialize_only() {
        let derives = vec!["Serialize".to_string()];
        let result = imports_for_derives(&derives);
        assert_eq!(result, vec!["use serde::{Serialize};"]);
    }

    #[test]
    fn test_imports_deserialize_only() {
        let derives = vec!["Deserialize".to_string()];
        let result = imports_for_derives(&derives);
        assert_eq!(result, vec!["use serde::{Deserialize};"]);
    }

    #[test]
    fn test_imports_both_serde() {
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let result = imports_for_derives(&derives);
        assert_eq!(result, vec!["use serde::{Serialize, Deserialize};"]);
    }

    #[test]
    fn test_imports_non_serde() {
        let derives = vec!["Hash".to_string()];
        let result = imports_for_derives(&derives);
        assert!(result.is_empty());
    }

    #[test]
    fn test_imports_non_serde_multiple() {
        let derives = vec!["PartialEq".to_string(), "Eq".to_string()];
        let result = imports_for_derives(&derives);
        assert!(result.is_empty());
    }

    #[test]
    fn test_imports_mixed_serde_and_others() {
        let derives = vec![
            "Serialize".to_string(),
            "Hash".to_string(),
            "Deserialize".to_string(),
        ];
        let result = imports_for_derives(&derives);
        assert_eq!(result.len(), 1);
        assert!(result[0].contains("Serialize"));
        assert!(result[0].contains("Deserialize"));
    }

    // ========== add_blank_lines_between_items ==========

    #[test]
    fn test_blank_lines_between_renamed_variants() {
        let input = "pub enum Foo {\n #[sqlx(rename = \"a\")]\n A,\n #[sqlx(rename = \"b\")]\n B,\n}";
        let result = add_blank_lines_between_items(input);
        assert!(result.contains("A,\n\n #[sqlx(rename = \"b\")]"));
    }

    #[test]
    fn test_no_blank_line_for_first_variant() {
        let input = "pub enum Foo {\n #[sqlx(rename = \"a\")]\n A,\n}";
        let result = add_blank_lines_between_items(input);
        assert!(!result.contains("{\n\n"));
    }

    #[test]
    fn test_no_change_without_rename() {
        let input = "pub enum Foo {\n A,\n B,\n}";
        let result = add_blank_lines_between_items(input);
        assert_eq!(result, input);
    }

    #[test]
    fn test_no_change_for_struct() {
        let input = "pub struct Foo {\n pub a: i32,\n pub b: String,\n}";
        let result = add_blank_lines_between_items(input);
        assert_eq!(result, input);
    }

    // ========== filter_imports ==========

    #[test]
    fn test_filter_single_file_strips_super_types() {
        let mut imports = BTreeSet::new();
        imports.insert("use super::types::Foo;".to_string());
        imports.insert("use chrono::NaiveDateTime;".to_string());
        let result = filter_imports(&imports, true);
        assert!(!result.contains("use super::types::Foo;"));
        assert!(result.contains("use chrono::NaiveDateTime;"));
    }

    #[test]
    fn test_filter_single_file_keeps_other_imports() {
        let mut imports = BTreeSet::new();
        imports.insert("use chrono::NaiveDateTime;".to_string());
        let result = filter_imports(&imports, true);
        assert!(result.contains("use chrono::NaiveDateTime;"));
    }

    #[test]
    fn test_filter_multi_file_keeps_all() {
        let mut imports = BTreeSet::new();
        imports.insert("use super::types::Foo;".to_string());
        imports.insert("use chrono::NaiveDateTime;".to_string());
        let result = filter_imports(&imports, false);
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_filter_empty_set() {
        let imports = BTreeSet::new();
        let result = filter_imports(&imports, true);
        assert!(result.is_empty());
    }

    // ========== generate ==========

    /// Builds a table in the `public` schema.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// Builds a non-nullable, non-key column whose `data_type` and `udt_name`
    /// are both `udt_name`.
    fn make_col(name: &str, udt_name: &str) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: false,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
            column_default: None,
        }
    }

    /// Builds a single-variant `public.status` enum.
    fn make_status_enum() -> EnumInfo {
        EnumInfo {
            schema_name: "public".to_string(),
            name: "status".to_string(),
            variants: vec!["active".to_string()],
            default_variant: None,
        }
    }

    #[test]
    fn test_generate_empty_schema() {
        let schema = SchemaInfo::default();
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert!(files.is_empty());
    }

    #[test]
    fn test_generate_one_table() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].filename, "users.rs");
    }

    #[test]
    fn test_generate_two_tables() {
        let schema = SchemaInfo {
            tables: vec![
                make_table("users", vec![make_col("id", "int4")]),
                make_table("posts", vec![make_col("id", "int4")]),
            ],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 2);
    }

    #[test]
    fn test_generate_enum_creates_types_file() {
        let schema = SchemaInfo {
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string(), "inactive".to_string()],
                default_variant: None,
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].filename, "types.rs");
    }

    #[test]
    fn test_generate_enums_composites_domains_single_types_file() {
        let schema = SchemaInfo {
            enums: vec![make_status_enum()],
            composite_types: vec![CompositeTypeInfo {
                schema_name: "public".to_string(),
                name: "address".to_string(),
                fields: vec![make_col("street", "text")],
            }],
            domains: vec![DomainInfo {
                schema_name: "public".to_string(),
                name: "email".to_string(),
                base_type: "text".to_string(),
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        let types_files: Vec<_> = files.iter().filter(|f| f.filename == "types.rs").collect();
        assert_eq!(types_files.len(), 1);
    }

    #[test]
    fn test_generate_tables_and_enums() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            enums: vec![make_status_enum()],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 2);
    }

    #[test]
    fn test_generate_filename_normalized() {
        let schema = SchemaInfo {
            tables: vec![make_table("user__data", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files[0].filename, "user_data.rs");
    }

    #[test]
    fn test_generate_origin_correct() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files[0].origin, Some("Table: public.users".to_string()));
    }

    #[test]
    fn test_generate_types_no_origin() {
        let schema = SchemaInfo {
            enums: vec![make_status_enum()],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files[0].origin, None);
    }

    #[test]
    fn test_generate_single_file_filters_super_types_imports() {
        // The column must reference the enum; otherwise no `super::types::`
        // import is ever produced and the filter is never exercised.
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("status", "status")])],
            enums: vec![make_status_enum()],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), true);
        let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
        assert!(!struct_file.code.contains("super::types::"));
    }

    #[test]
    fn test_generate_multi_file_keeps_super_types_imports() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("status", "status")])],
            enums: vec![make_status_enum()],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
        assert!(struct_file.code.contains("super::types::"));
    }

    #[test]
    fn test_generate_extra_derives_in_struct() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let derives = vec!["Serialize".to_string()];
        let files = generate(&schema, DatabaseKind::Postgres, &derives, &HashMap::new(), false);
        assert!(files[0].code.contains("Serialize"));
    }

    #[test]
    fn test_generate_extra_derives_in_enum() {
        let schema = SchemaInfo {
            enums: vec![make_status_enum()],
            ..Default::default()
        };
        let derives = vec!["Serialize".to_string()];
        let files = generate(&schema, DatabaseKind::Postgres, &derives, &HashMap::new(), false);
        assert!(files[0].code.contains("Serialize"));
    }

    #[test]
    fn test_generate_type_overrides_in_struct() {
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("data", "jsonb")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &overrides, false);
        assert!(files[0].code.contains("MyJson"));
    }

    #[test]
    fn test_generate_valid_rust_syntax() {
        let schema = SchemaInfo {
            tables: vec![make_table(
                "users",
                vec![make_col("id", "int4"), make_col("name", "text")],
            )],
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string(), "inactive".to_string()],
                default_variant: None,
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        for f in &files {
            let parse_result = syn::parse_file(&f.code);
            assert!(
                parse_result.is_ok(),
                "Failed to parse {}: {:?}",
                f.filename,
                parse_result.err()
            );
        }
    }

    /// Views use the same `TableInfo` shape as tables.
    fn make_view(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        make_table(name, columns)
    }

    #[test]
    fn test_generate_one_view() {
        let schema = SchemaInfo {
            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].filename, "active_users.rs");
    }

    #[test]
    fn test_generate_view_origin() {
        let schema = SchemaInfo {
            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files[0].origin, Some("View: public.active_users".to_string()));
    }

    #[test]
    fn test_generate_tables_and_views() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 2);
    }

    #[test]
    fn test_generate_view_valid_rust() {
        let schema = SchemaInfo {
            views: vec![make_view(
                "active_users",
                vec![make_col("id", "int4"), make_col("name", "text")],
            )],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        let parse_result = syn::parse_file(&files[0].code);
        assert!(parse_result.is_ok(), "Failed to parse: {:?}", parse_result.err());
    }

    #[test]
    fn test_generate_view_nullable_column() {
        let schema = SchemaInfo {
            views: vec![make_view(
                "v",
                vec![ColumnInfo {
                    name: "email".to_string(),
                    data_type: "text".to_string(),
                    udt_name: "text".to_string(),
                    is_nullable: true,
                    is_primary_key: false,
                    ordinal_position: 0,
                    schema_name: "public".to_string(),
                    column_default: None,
                }],
            )],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert!(files[0].code.contains("Option<String>"));
    }

    #[test]
    fn test_generate_collision_both_prefixed() {
        let schema = SchemaInfo {
            tables: vec![
                make_table("users", vec![make_col("id", "int4")]),
                TableInfo {
                    schema_name: "billing".to_string(),
                    name: "users".to_string(),
                    columns: vec![make_col("id", "int4")],
                },
            ],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        let filenames: Vec<_> = files.iter().map(|f| f.filename.as_str()).collect();
        assert!(filenames.contains(&"users.rs"));
        assert!(filenames.contains(&"billing_users.rs"));
    }

    #[test]
    fn test_generate_no_collision_no_prefix() {
        let schema = SchemaInfo {
            tables: vec![
                make_table("users", vec![make_col("id", "int4")]),
                TableInfo {
                    schema_name: "billing".to_string(),
                    name: "invoices".to_string(),
                    columns: vec![make_col("id", "int4")],
                },
            ],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        let filenames: Vec<_> = files.iter().map(|f| f.filename.as_str()).collect();
        assert!(filenames.contains(&"users.rs"));
        assert!(filenames.contains(&"invoices.rs"));
    }

    #[test]
    fn test_generate_single_schema_no_prefix() {
        let schema = SchemaInfo {
            tables: vec![
                make_table("users", vec![make_col("id", "int4")]),
                make_table("posts", vec![make_col("id", "int4")]),
            ],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files[0].filename, "users.rs");
        assert_eq!(files[1].filename, "posts.rs");
    }

    #[test]
    fn test_generate_view_single_file_mode() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), true);
        assert_eq!(files.len(), 2);
    }

    // ========== parse_pg_enum_default ==========

    #[test]
    fn test_parse_pg_enum_default_simple() {
        assert_eq!(
            parse_pg_enum_default("'idle'::task_status"),
            Some("idle".to_string())
        );
    }

    #[test]
    fn test_parse_pg_enum_default_schema_qualified() {
        assert_eq!(
            parse_pg_enum_default("'active'::public.task_status"),
            Some("active".to_string())
        );
    }

    #[test]
    fn test_parse_pg_enum_default_not_enum() {
        assert_eq!(parse_pg_enum_default("nextval('users_id_seq')"), None);
    }

    #[test]
    fn test_parse_pg_enum_default_no_cast() {
        assert_eq!(parse_pg_enum_default("'hello'"), None);
    }

    #[test]
    fn test_parse_pg_enum_default_empty() {
        assert_eq!(parse_pg_enum_default(""), None);
    }

    // ========== extract_enum_defaults ==========

    /// Builds a `tasks` table whose `status` column is typed `task_status`
    /// with the given column default.
    fn make_tasks_table(column_default: Option<&str>) -> TableInfo {
        TableInfo {
            schema_name: "public".to_string(),
            name: "tasks".to_string(),
            columns: vec![ColumnInfo {
                name: "status".to_string(),
                data_type: "USER-DEFINED".to_string(),
                udt_name: "task_status".to_string(),
                is_nullable: false,
                is_primary_key: false,
                ordinal_position: 0,
                schema_name: "public".to_string(),
                column_default: column_default.map(str::to_string),
            }],
        }
    }

    /// Builds the `public.task_status` enum with the given variants.
    fn make_task_status_enum(variants: &[&str]) -> EnumInfo {
        EnumInfo {
            schema_name: "public".to_string(),
            name: "task_status".to_string(),
            variants: variants.iter().map(|v| v.to_string()).collect(),
            default_variant: None,
        }
    }

    #[test]
    fn test_extract_enum_defaults_from_column() {
        let schema = SchemaInfo {
            tables: vec![make_tasks_table(Some("'idle'::task_status"))],
            enums: vec![make_task_status_enum(&["idle", "running"])],
            ..Default::default()
        };
        let defaults = extract_enum_defaults(&schema);
        assert_eq!(defaults.get("task_status"), Some(&"idle".to_string()));
    }

    #[test]
    fn test_extract_enum_defaults_no_default() {
        let schema = SchemaInfo {
            tables: vec![make_tasks_table(None)],
            enums: vec![make_task_status_enum(&["idle"])],
            ..Default::default()
        };
        let defaults = extract_enum_defaults(&schema);
        assert!(defaults.is_empty());
    }

    #[test]
    fn test_extract_enum_defaults_non_enum_column_ignored() {
        let schema = SchemaInfo {
            tables: vec![TableInfo {
                schema_name: "public".to_string(),
                name: "users".to_string(),
                columns: vec![ColumnInfo {
                    name: "name".to_string(),
                    data_type: "character varying".to_string(),
                    udt_name: "varchar".to_string(),
                    is_nullable: false,
                    is_primary_key: false,
                    ordinal_position: 0,
                    schema_name: "public".to_string(),
                    column_default: Some("'hello'::character varying".to_string()),
                }],
            }],
            enums: vec![],
            ..Default::default()
        };
        let defaults = extract_enum_defaults(&schema);
        assert!(defaults.is_empty());
    }

    #[test]
    fn test_generate_enum_with_default() {
        let schema = SchemaInfo {
            tables: vec![make_tasks_table(Some("'idle'::task_status"))],
            enums: vec![make_task_status_enum(&["idle", "running"])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        let types_file = files.iter().find(|f| f.filename == "types.rs").unwrap();
        assert!(types_file.code.contains("impl Default for TaskStatus"));
        assert!(types_file.code.contains("Self::Idle"));
    }
}
1216}