1pub mod composite_gen;
2pub mod crud_gen;
3pub mod domain_gen;
4pub mod entity_parser;
5pub mod enum_gen;
6pub mod identifiers;
7pub mod naming;
8pub mod struct_gen;
9
10use std::collections::{BTreeSet, HashMap};
11use std::path::Path;
12
13use proc_macro2::TokenStream;
14
15use crate::cli::{DatabaseKind, TimeCrate};
16use crate::introspect::SchemaInfo;
17
/// Rust keywords (strict, reserved, and edition-specific) that cannot be used as plain
/// identifiers in generated code.
///
/// Includes `gen`, which is reserved starting with the Rust 2024 edition; listing it here makes
/// callers treat a column/type named `gen` like any other keyword so the generated code stays
/// valid for consumers on the newer edition.
const RUST_KEYWORDS: &[&str] = &[
    "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum", "extern",
    "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub",
    "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true", "type",
    "unsafe", "use", "where", "while", "yield", "abstract", "become", "box", "do", "final",
    "macro", "override", "priv", "try", "typeof", "unsized", "virtual", "gen",
];

/// Returns `true` if `name` is one of [`RUST_KEYWORDS`].
///
/// The comparison is case-sensitive, so e.g. `"Type"` is not a keyword while `"type"` is.
pub fn is_rust_keyword(name: &str) -> bool {
    RUST_KEYWORDS.contains(&name)
}
31
/// Returns the `use` lines required by the requested extra derive macros.
///
/// Only serde's `Serialize` and `Deserialize` need an import. When either is requested they are
/// combined into one braced line, always ordered `Serialize, Deserialize`, e.g.
/// `use serde::{Serialize, Deserialize};`. Any other derive contributes nothing, and an empty
/// vector is returned when no serde derive is present.
pub fn imports_for_derives(extra_derives: &[String]) -> Vec<String> {
    let requested = |wanted: &str| extra_derives.iter().any(|derive| derive.as_str() == wanted);

    let serde_traits: Vec<&str> = ["Serialize", "Deserialize"]
        .iter()
        .copied()
        .filter(|&name| requested(name))
        .collect();

    if serde_traits.is_empty() {
        Vec::new()
    } else {
        vec![format!("use serde::{{{}}};", serde_traits.join(", "))]
    }
}
48
/// Collapses every run of consecutive underscores in `name` into a single underscore.
///
/// All other characters are copied unchanged, including a lone leading or trailing underscore:
/// `"user__roles"` becomes `"user_roles"` and `"__x"` becomes `"_x"`.
pub fn normalize_module_name(name: &str) -> String {
    name.chars()
        .fold(String::with_capacity(name.len()), |mut out, ch| {
            // An underscore is redundant exactly when the output already ends with one.
            if !(ch == '_' && out.ends_with('_')) {
                out.push(ch);
            }
            out
        })
}
67
/// Schema names treated as a database's default namespace.
const DEFAULT_SCHEMAS: &[&str] = &["public", "main", "dbo"];

/// Returns `true` when `schema` is one of [`DEFAULT_SCHEMAS`] (exact, case-sensitive match).
///
/// Callers in this module use it to decide that objects in such a schema never get a
/// schema-name prefix.
pub fn is_default_schema(schema: &str) -> bool {
    DEFAULT_SCHEMAS.iter().any(|&default| default == schema)
}
75
76pub fn rust_type_name_for(schema_info: &SchemaInfo, schema: &str, name: &str) -> String {
83 use heck::ToUpperCamelCase;
84 if type_name_has_cross_schema_collision(schema_info, name) && !is_default_schema(schema) {
85 format!(
86 "{}{}",
87 schema.to_upper_camel_case(),
88 name.to_upper_camel_case()
89 )
90 } else {
91 name.to_upper_camel_case()
92 }
93}
94
95pub fn required_pg_search_path(schema_info: &SchemaInfo) -> Vec<String> {
108 let mut schemas: std::collections::BTreeSet<String> = std::collections::BTreeSet::new();
109 for e in &schema_info.enums {
110 if !is_default_schema(&e.schema_name) {
111 schemas.insert(e.schema_name.clone());
112 }
113 }
114 for c in &schema_info.composite_types {
115 if !is_default_schema(&c.schema_name) {
116 schemas.insert(c.schema_name.clone());
117 }
118 }
119 for d in &schema_info.domains {
120 if !is_default_schema(&d.schema_name) {
121 schemas.insert(d.schema_name.clone());
122 }
123 }
124 schemas.into_iter().collect()
125}
126
127pub fn type_name_has_cross_schema_collision(schema_info: &SchemaInfo, name: &str) -> bool {
130 let mut schemas: std::collections::BTreeSet<&str> = std::collections::BTreeSet::new();
131 schemas.extend(
132 schema_info
133 .enums
134 .iter()
135 .filter(|e| e.name == name)
136 .map(|e| e.schema_name.as_str()),
137 );
138 schemas.extend(
139 schema_info
140 .composite_types
141 .iter()
142 .filter(|c| c.name == name)
143 .map(|c| c.schema_name.as_str()),
144 );
145 schemas.extend(
146 schema_info
147 .domains
148 .iter()
149 .filter(|d| d.name == name)
150 .map(|d| d.schema_name.as_str()),
151 );
152 schemas.len() > 1
153}
154
155pub fn build_module_name(schema_name: &str, table_name: &str, name_collides: bool) -> String {
158 if name_collides && !is_default_schema(schema_name) {
159 normalize_module_name(&format!("{}_{}", schema_name, table_name))
160 } else {
161 normalize_module_name(table_name)
162 }
163}
164
165fn find_colliding_names(schema_info: &SchemaInfo) -> BTreeSet<&str> {
167 let mut seen: HashMap<&str, BTreeSet<&str>> = HashMap::new();
168 for t in &schema_info.tables {
169 seen.entry(t.name.as_str())
170 .or_default()
171 .insert(t.schema_name.as_str());
172 }
173 for v in &schema_info.views {
174 seen.entry(v.name.as_str())
175 .or_default()
176 .insert(v.schema_name.as_str());
177 }
178 seen.into_iter()
179 .filter(|(_, schemas)| schemas.len() > 1)
180 .map(|(name, _)| name)
181 .collect()
182}
183
/// One generated Rust source file, as returned by [`generate`] / [`generate_with_domain_style`].
#[derive(Debug, Clone)]
pub struct GeneratedFile {
    /// File name including the `.rs` extension (e.g. `users.rs` or `types.rs`).
    pub filename: String,
    /// Optional origin label. Every constructor in this module sets it to `None`.
    /// NOTE(review): its meaning is defined by consumers elsewhere — confirm before relying on it.
    pub origin: Option<String>,
    /// Formatted Rust source code of the file.
    pub code: String,
}
192
193pub fn generate(
195 schema_info: &SchemaInfo,
196 db_kind: DatabaseKind,
197 extra_derives: &[String],
198 type_overrides: &HashMap<String, String>,
199 single_file: bool,
200 time_crate: TimeCrate,
201) -> crate::error::Result<Vec<GeneratedFile>> {
202 generate_with_domain_style(
203 schema_info,
204 db_kind,
205 extra_derives,
206 type_overrides,
207 single_file,
208 time_crate,
209 crate::cli::DomainStyle::Alias,
210 )
211}
212
213pub fn generate_with_domain_style(
216 schema_info: &SchemaInfo,
217 db_kind: DatabaseKind,
218 extra_derives: &[String],
219 type_overrides: &HashMap<String, String>,
220 single_file: bool,
221 time_crate: TimeCrate,
222 domain_style: crate::cli::DomainStyle,
223) -> crate::error::Result<Vec<GeneratedFile>> {
224 let mut files = Vec::new();
225
226 let colliding_names = find_colliding_names(schema_info);
228
229 for table in &schema_info.tables {
231 let (tokens, imports) = struct_gen::generate_struct(
232 table,
233 db_kind,
234 schema_info,
235 extra_derives,
236 type_overrides,
237 false,
238 time_crate,
239 );
240 let imports = filter_imports(&imports, single_file);
241 let code = format_tokens_with_imports(&tokens, &imports)?;
242 let module_name = build_module_name(
243 &table.schema_name,
244 &table.name,
245 colliding_names.contains(table.name.as_str()),
246 );
247 files.push(GeneratedFile {
248 filename: format!("{}.rs", module_name),
249 origin: None,
250 code,
251 });
252 }
253
254 for view in &schema_info.views {
256 let (tokens, imports) = struct_gen::generate_struct(
257 view,
258 db_kind,
259 schema_info,
260 extra_derives,
261 type_overrides,
262 true,
263 time_crate,
264 );
265 let imports = filter_imports(&imports, single_file);
266 let code = format_tokens_with_imports(&tokens, &imports)?;
267 let module_name = build_module_name(
268 &view.schema_name,
269 &view.name,
270 colliding_names.contains(view.name.as_str()),
271 );
272 files.push(GeneratedFile {
273 filename: format!("{}.rs", module_name),
274 origin: None,
275 code,
276 });
277 }
278
279 let mut types_blocks: Vec<String> = Vec::new();
282 let mut types_imports = BTreeSet::new();
283
284 let enum_defaults = extract_enum_defaults(schema_info);
286 for enum_info in &schema_info.enums {
287 enum_gen::check_variant_collisions(enum_info)?;
288 let mut enriched = enum_info.clone();
289 if enriched.default_variant.is_none() {
290 if let Some(default) = enum_defaults.get(&enum_info.name) {
291 enriched.default_variant = Some(default.clone());
292 }
293 }
294 let (tokens, imports) =
295 enum_gen::generate_enum_with_schema(&enriched, db_kind, extra_derives, schema_info);
296 types_blocks.push(format_tokens(&tokens)?);
297 types_imports.extend(imports);
298 }
299
300 for composite in &schema_info.composite_types {
301 let (tokens, imports) = composite_gen::generate_composite(
302 composite,
303 db_kind,
304 schema_info,
305 extra_derives,
306 type_overrides,
307 time_crate,
308 );
309 types_blocks.push(format_tokens(&tokens)?);
310 types_imports.extend(imports);
311 }
312
313 for domain in &schema_info.domains {
314 let (tokens, imports) = domain_gen::generate_domain_with_style(
315 domain,
316 db_kind,
317 schema_info,
318 type_overrides,
319 time_crate,
320 domain_style,
321 );
322 types_blocks.push(format_tokens(&tokens)?);
323 types_imports.extend(imports);
324 }
325
326 if !types_blocks.is_empty() {
327 let import_lines: String = types_imports.iter().map(|i| format!("{}\n", i)).collect();
328 let body = types_blocks.join("\n");
329 let code = if import_lines.is_empty() {
330 body
331 } else {
332 format!("{}\n\n{}", import_lines.trim_end(), body)
333 };
334 files.push(GeneratedFile {
335 filename: "types.rs".to_string(),
336 origin: None,
337 code,
338 });
339 }
340
341 Ok(files)
342}
343
344fn extract_enum_defaults(schema_info: &SchemaInfo) -> HashMap<String, String> {
347 let mut defaults: HashMap<String, String> = HashMap::new();
348
349 let all_columns = schema_info
350 .tables
351 .iter()
352 .chain(schema_info.views.iter())
353 .flat_map(|t| t.columns.iter());
354
355 for col in all_columns {
356 let default_expr = match &col.column_default {
357 Some(d) => d,
358 None => continue,
359 };
360
361 let base_udt = col.udt_name.strip_prefix('_').unwrap_or(&col.udt_name);
363
364 let enum_match = schema_info.enums.iter().find(|e| e.name == base_udt);
366 if enum_match.is_none() {
367 continue;
368 }
369
370 if let Some(variant) = parse_pg_enum_default(default_expr) {
372 defaults.entry(base_udt.to_string()).or_insert(variant);
373 }
374 }
375
376 defaults
377}
378
/// Extracts the label from a Postgres enum column default such as `'active'::status`.
///
/// After trimming, the expression must begin with a single-quoted SQL string literal that is
/// immediately followed by a `::` cast. Otherwise (no leading quote, unterminated literal, or no
/// cast) `None` is returned. Inside the literal, a doubled quote `''` is SQL's escape for one
/// `'` and is unescaped, so `'it''s'::mood` yields `it's`.
fn parse_pg_enum_default(default_expr: &str) -> Option<String> {
    let literal = default_expr.trim().strip_prefix('\'')?;
    let mut value = String::new();
    let mut chars = literal.char_indices();

    while let Some((idx, ch)) = chars.next() {
        if ch != '\'' {
            value.push(ch);
            continue;
        }
        let rest = &literal[idx + 1..];
        if rest.starts_with('\'') {
            // `''` is an escaped quote, not the end of the literal.
            value.push('\'');
            chars.next();
            continue;
        }
        return if rest.starts_with("::") { Some(value) } else { None };
    }

    // Reached the end without a closing quote.
    None
}
392
/// Returns the imports to emit for a per-relation file.
///
/// In single-file mode every import mentioning `super::types::` is removed, because there is no
/// sibling `types` module to import from. Otherwise the set is returned unchanged (as a copy).
fn filter_imports(imports: &BTreeSet<String>, single_file: bool) -> BTreeSet<String> {
    imports
        .iter()
        .filter(|import| !single_file || !import.contains("super::types::"))
        .cloned()
        .collect()
}
405
/// Determines the indentation width from the nearest rustfmt config file.
///
/// Starting at `start_dir` (or its parent directory when a file path is given), each directory
/// up to the filesystem root is checked for `rustfmt.toml`, then `.rustfmt.toml`. The first
/// readable config file found decides the result: its `tab_spaces` value, or 4 if it has no
/// usable `tab_spaces` setting. Returns 4 when no config file exists.
pub fn detect_tab_spaces(start_dir: &Path) -> usize {
    const DEFAULT_TAB_SPACES: usize = 4;

    let mut dir = if start_dir.is_file() {
        start_dir.parent().unwrap_or(start_dir)
    } else {
        start_dir
    };
    loop {
        for name in &["rustfmt.toml", ".rustfmt.toml"] {
            if let Ok(content) = std::fs::read_to_string(dir.join(name)) {
                // The nearest config file wins even without `tab_spaces` (no merging upward).
                return parse_tab_spaces(&content).unwrap_or(DEFAULT_TAB_SPACES);
            }
        }
        match dir.parent() {
            Some(parent) => dir = parent,
            None => return DEFAULT_TAB_SPACES,
        }
    }
}

/// Extracts the first `tab_spaces = N` assignment from rustfmt TOML text.
///
/// The key must match exactly (so `tab_spaces_x = 3` is ignored), and a trailing `# comment` is
/// allowed. The value is an integer, so `#` cannot legitimately appear inside it.
fn parse_tab_spaces(content: &str) -> Option<usize> {
    content.lines().find_map(|line| {
        let line = line.split('#').next().unwrap_or("");
        let (key, value) = line.split_once('=')?;
        if key.trim() != "tab_spaces" {
            return None;
        }
        value.trim().parse().ok()
    })
}
437
438pub(crate) fn parse_and_format(tokens: &TokenStream) -> crate::error::Result<String> {
441 parse_and_format_with_tab_spaces(tokens, 4)
442}
443
444pub(crate) fn parse_and_format_with_tab_spaces(
445 tokens: &TokenStream,
446 tab_spaces: usize,
447) -> crate::error::Result<String> {
448 let file = syn::parse2::<syn::File>(tokens.clone()).map_err(|e| {
449 crate::error::Error::Config(format!(
450 "Internal sqlx-gen bug: failed to parse generated code: {}. \
451 Raw tokens:\n {}\n\
452 Please report this with the input schema.",
453 e, tokens
454 ))
455 })?;
456 let raw = prettyplease::unparse(&file);
457 let raw = indent_multiline_raw_strings(&raw, tab_spaces);
458 Ok(add_blank_lines_between_items(&raw))
459}
460
461pub(crate) fn format_tokens(tokens: &TokenStream) -> crate::error::Result<String> {
463 parse_and_format(tokens)
464}
465
466pub fn format_tokens_with_imports(
467 tokens: &TokenStream,
468 imports: &BTreeSet<String>,
469) -> crate::error::Result<String> {
470 format_tokens_with_imports_and_tab_spaces(tokens, imports, 4)
471}
472
473pub fn format_tokens_with_imports_and_tab_spaces(
474 tokens: &TokenStream,
475 imports: &BTreeSet<String>,
476 tab_spaces: usize,
477) -> crate::error::Result<String> {
478 let formatted = parse_and_format_with_tab_spaces(tokens, tab_spaces)?;
479
480 let used_imports: Vec<&String> = imports
481 .iter()
482 .filter(|imp| is_import_used(imp, &formatted))
483 .collect();
484
485 if used_imports.is_empty() {
486 Ok(formatted)
487 } else {
488 let import_lines: String = used_imports.iter().map(|i| format!("{}\n", i)).collect();
489 Ok(format!("{}\n\n{}", import_lines.trim_end(), formatted))
490 }
491}
492
/// Heuristically decides whether the `use` line `import` is referenced by `code`.
///
/// * Glob imports (`use a::*;`) are always kept.
/// * `use a::{X, Y as Z, self}` counts as used if any name it binds (`X`, `Z`, `a`) appears.
/// * `use a::X` / `use a::X as Y` counts as used if the bound name (`X` / `Y`) appears.
/// * `as _` imports bind no name and are always kept.
///
/// Names are matched as whole identifiers, so `NaiveDate` is not considered used merely
/// because `NaiveDateTime` appears.
fn is_import_used(import: &str, code: &str) -> bool {
    let trimmed = import.trim().trim_end_matches(';');
    let path = trimmed.strip_prefix("use ").unwrap_or(trimmed);

    if path.ends_with("::*") {
        return true;
    }

    let is_used = |name: &str| name == "_" || contains_identifier(code, name);

    if let Some(open) = path.find('{') {
        // Search for `}` only after `{` so malformed input cannot produce an inverted slice.
        if let Some(len) = path[open + 1..].find('}') {
            let parent = path[..open].trim_end_matches("::");
            return path[open + 1..open + 1 + len]
                .split(',')
                .map(str::trim)
                .filter(|item| !item.is_empty())
                .any(|item| is_used(import_bound_name(item, parent)));
        }
    }

    is_used(import_bound_name(path, ""))
}

/// Returns the local name bound by one import item: the alias of `X as Y`, otherwise the last
/// `::` segment. A `self` item binds the last segment of `parent`.
fn import_bound_name<'a>(item: &'a str, parent: &'a str) -> &'a str {
    let item = item.trim();
    let name = match item.rsplit_once(" as ") {
        Some((_, alias)) => alias.trim(),
        None => item.rsplit("::").next().unwrap_or(item).trim(),
    };
    if name == "self" {
        parent.rsplit("::").next().unwrap_or(parent).trim()
    } else {
        name
    }
}

/// Returns `true` if `ident` occurs in `code` not embedded in a longer identifier.
fn contains_identifier(code: &str, ident: &str) -> bool {
    if ident.is_empty() {
        return false;
    }
    let is_ident_char = |c: char| c.is_alphanumeric() || c == '_';
    code.match_indices(ident).any(|(start, _)| {
        let end = start + ident.len();
        let clear_before = code[..start]
            .chars()
            .next_back()
            .map_or(true, |c| !is_ident_char(c));
        let clear_after = code[end..].chars().next().map_or(true, |c| !is_ident_char(c));
        clear_before && clear_after
    })
}
525
/// Re-indents the contents of multi-line `r#"..."#` literals in prettyplease output.
///
/// For a literal that opens on one line (`r#"` with no `"#` after it) and closes on a line
/// starting with `"#`:
/// * content lines are shifted to `4 + 2 * tab_spaces` columns, keeping their indentation
///   relative to the least-indented content line;
/// * blank content lines become empty;
/// * the closing line is placed at `4 + tab_spaces` columns.
///
/// If a literal instead closes in the middle of a line, or is still open at the end of the
/// input, its lines are emitted verbatim, so no line is ever dropped. Other lines pass through
/// unchanged. The output is joined with `\n`, without a trailing newline.
fn indent_multiline_raw_strings(code: &str, tab_spaces: usize) -> String {
    // Target columns: the constant 4 is presumably the base indent of the enclosing fn body.
    let close_indent = 4 + tab_spaces;
    let sql_indent = 4 + 2 * tab_spaces;
    let lines: Vec<&str> = code.lines().collect();
    let mut result = Vec::with_capacity(lines.len());
    let mut inside_raw = false;
    let mut raw_lines: Vec<&str> = Vec::new();

    for line in &lines {
        if !inside_raw {
            if let Some(pos) = line.find("r#\"") {
                let after = &line[pos + 3..];
                if !after.contains("\"#") {
                    inside_raw = true;
                    raw_lines.clear();
                }
            }
            result.push(line.to_string());
        } else if line.trim_start().starts_with("\"#") {
            let min_indent = raw_lines
                .iter()
                .filter(|l| !l.trim().is_empty())
                .map(|l| l.len() - l.trim_start().len())
                .min()
                .unwrap_or(0);
            for raw_line in &raw_lines {
                let trimmed = raw_line.trim();
                if trimmed.is_empty() {
                    result.push(String::new());
                } else {
                    let original_indent = raw_line.len() - raw_line.trim_start().len();
                    let relative = original_indent.saturating_sub(min_indent);
                    result.push(format!(
                        "{}{}{}",
                        " ".repeat(sql_indent),
                        " ".repeat(relative),
                        trimmed
                    ));
                }
            }
            let trimmed = line.trim();
            result.push(format!("{}{}", " ".repeat(close_indent), trimmed));
            inside_raw = false;
        } else if line.contains("\"#") {
            // The literal closes mid-line: re-indenting it is not safe, so keep it verbatim and
            // resume normal processing (previously every later line was silently discarded).
            result.extend(raw_lines.drain(..).map(str::to_string));
            result.push(line.to_string());
            inside_raw = false;
        } else {
            raw_lines.push(line);
        }
    }

    // Unterminated literal at end of input: never drop the buffered lines.
    if inside_raw {
        result.extend(raw_lines.iter().map(|l| l.to_string()));
    }

    result.join("\n")
}
588
/// Inserts blank lines into formatted code for readability.
///
/// At most one blank line is inserted before a line, and only when one of the rules in
/// [`needs_blank_line_before`] matches. Line content is never changed. The result is joined
/// with `\n`, without a trailing newline.
fn add_blank_lines_between_items(code: &str) -> String {
    let lines: Vec<&str> = code.lines().collect();
    let mut result: Vec<&str> = Vec::with_capacity(lines.len());

    for (i, line) in lines.iter().enumerate() {
        if i > 0 && needs_blank_line_before(lines[i - 1].trim(), line.trim()) {
            result.push("");
        }
        result.push(line);
    }

    result.join("\n")
}

/// Decides whether a blank line should separate the trimmed lines `prev` and `current`.
///
/// Several rules can match the same pair; the caller still inserts only one blank line
/// (previously two were emitted when both the await/`unwrap_or` rule and the sqlx-`let` rule
/// matched).
fn needs_blank_line_before(prev: &str, current: &str) -> bool {
    // A `#[sqlx(rename ..)]`-annotated enum variant following a previous variant.
    let renamed_variant = current.starts_with("#[sqlx(rename") && prev.ends_with(',');

    // An item starting right after the closing brace of a previous block.
    let item_after_block = prev == "}"
        && ["pub struct", "impl ", "#[derive", "pub async fn", "pub fn"]
            .iter()
            .any(|prefix| current.starts_with(*prefix));

    // A `let`/`Ok(` statement following an awaited call or an `unwrap_or(..)` statement.
    let prev_is_await_end = prev.ends_with(".await?;")
        || prev.ends_with(".await?")
        || (prev.ends_with(';') && prev.contains(".unwrap_or("));
    let after_await = prev_is_await_end && (current.starts_with("let ") || current.starts_with("Ok("));

    // A `let` that involves `sqlx::` following a plain `let`.
    let sqlx_let_after_plain_let = current.starts_with("let ")
        && current.contains("sqlx::")
        && prev.starts_with("let ")
        && !prev.contains("sqlx::");

    renamed_variant || item_after_block || after_await || sqlx_let_after_plain_let
}
646
647#[cfg(test)]
648mod tests {
649 use super::*;
650 use crate::introspect::{
651 ColumnInfo, CompositeTypeInfo, DomainInfo, EnumInfo, SchemaInfo, TableInfo,
652 };
653 use std::collections::HashMap;
654
655 #[test]
658 fn test_keyword_type() {
659 assert!(is_rust_keyword("type"));
660 }
661
662 #[test]
663 fn test_keyword_fn() {
664 assert!(is_rust_keyword("fn"));
665 }
666
667 #[test]
668 fn test_keyword_let() {
669 assert!(is_rust_keyword("let"));
670 }
671
672 #[test]
673 fn test_keyword_match() {
674 assert!(is_rust_keyword("match"));
675 }
676
677 #[test]
678 fn test_keyword_async() {
679 assert!(is_rust_keyword("async"));
680 }
681
682 #[test]
683 fn test_keyword_await() {
684 assert!(is_rust_keyword("await"));
685 }
686
687 #[test]
688 fn test_keyword_yield() {
689 assert!(is_rust_keyword("yield"));
690 }
691
692 #[test]
693 fn test_keyword_abstract() {
694 assert!(is_rust_keyword("abstract"));
695 }
696
697 #[test]
698 fn test_keyword_try() {
699 assert!(is_rust_keyword("try"));
700 }
701
702 #[test]
703 fn test_not_keyword_name() {
704 assert!(!is_rust_keyword("name"));
705 }
706
707 #[test]
708 fn test_not_keyword_id() {
709 assert!(!is_rust_keyword("id"));
710 }
711
712 #[test]
713 fn test_not_keyword_uppercase_type() {
714 assert!(!is_rust_keyword("Type"));
715 }
716
717 #[test]
720 fn test_normalize_no_underscores() {
721 assert_eq!(normalize_module_name("users"), "users");
722 }
723
724 #[test]
725 fn test_normalize_single_underscore() {
726 assert_eq!(normalize_module_name("user_roles"), "user_roles");
727 }
728
729 #[test]
730 fn test_normalize_double_underscore() {
731 assert_eq!(normalize_module_name("user__roles"), "user_roles");
732 }
733
734 #[test]
735 fn test_normalize_triple_underscore() {
736 assert_eq!(normalize_module_name("a___b"), "a_b");
737 }
738
739 #[test]
740 fn test_normalize_leading_underscore() {
741 assert_eq!(normalize_module_name("_private"), "_private");
742 }
743
744 #[test]
745 fn test_normalize_trailing_underscore() {
746 assert_eq!(normalize_module_name("name_"), "name_");
747 }
748
749 #[test]
750 fn test_normalize_double_leading() {
751 assert_eq!(normalize_module_name("__double_leading"), "_double_leading");
752 }
753
754 #[test]
755 fn test_normalize_multiple_groups() {
756 assert_eq!(normalize_module_name("a__b__c"), "a_b_c");
757 }
758
759 #[test]
762 fn test_build_no_collision_no_prefix() {
763 assert_eq!(build_module_name("public", "users", false), "users");
764 }
765
766 #[test]
767 fn test_build_no_collision_non_default_no_prefix() {
768 assert_eq!(build_module_name("billing", "invoices", false), "invoices");
769 }
770
771 #[test]
772 fn test_build_collision_prefixed() {
773 assert_eq!(build_module_name("billing", "users", true), "billing_users");
774 }
775
776 #[test]
777 fn test_build_collision_default_schema_no_prefix() {
778 assert_eq!(build_module_name("public", "users", true), "users");
779 }
780
781 #[test]
782 fn test_build_collision_normalizes_double_underscore() {
783 assert_eq!(
784 build_module_name("billing", "agent__connector", true),
785 "billing_agent_connector"
786 );
787 }
788
789 #[test]
792 fn test_default_schema_public() {
793 assert!(is_default_schema("public"));
794 }
795
796 #[test]
797 fn test_default_schema_main() {
798 assert!(is_default_schema("main"));
799 }
800
801 #[test]
802 fn test_non_default_schema() {
803 assert!(!is_default_schema("billing"));
804 }
805
806 #[test]
809 fn test_imports_empty() {
810 let result = imports_for_derives(&[]);
811 assert!(result.is_empty());
812 }
813
814 #[test]
815 fn test_imports_serialize_only() {
816 let derives = vec!["Serialize".to_string()];
817 let result = imports_for_derives(&derives);
818 assert_eq!(result, vec!["use serde::{Serialize};"]);
819 }
820
821 #[test]
822 fn test_imports_deserialize_only() {
823 let derives = vec!["Deserialize".to_string()];
824 let result = imports_for_derives(&derives);
825 assert_eq!(result, vec!["use serde::{Deserialize};"]);
826 }
827
828 #[test]
829 fn test_imports_both_serde() {
830 let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
831 let result = imports_for_derives(&derives);
832 assert_eq!(result, vec!["use serde::{Serialize, Deserialize};"]);
833 }
834
835 #[test]
836 fn test_imports_non_serde() {
837 let derives = vec!["Hash".to_string()];
838 let result = imports_for_derives(&derives);
839 assert!(result.is_empty());
840 }
841
842 #[test]
843 fn test_imports_non_serde_multiple() {
844 let derives = vec!["PartialEq".to_string(), "Eq".to_string()];
845 let result = imports_for_derives(&derives);
846 assert!(result.is_empty());
847 }
848
849 #[test]
850 fn test_imports_mixed_serde_and_others() {
851 let derives = vec![
852 "Serialize".to_string(),
853 "Hash".to_string(),
854 "Deserialize".to_string(),
855 ];
856 let result = imports_for_derives(&derives);
857 assert_eq!(result.len(), 1);
858 assert!(result[0].contains("Serialize"));
859 assert!(result[0].contains("Deserialize"));
860 }
861
862 #[test]
865 fn test_blank_lines_between_renamed_variants() {
866 let input = "pub enum Foo {\n #[sqlx(rename = \"a\")]\n A,\n #[sqlx(rename = \"b\")]\n B,\n}";
867 let result = add_blank_lines_between_items(input);
868 assert!(result.contains("A,\n\n #[sqlx(rename = \"b\")]"));
869 }
870
871 #[test]
872 fn test_no_blank_line_for_first_variant() {
873 let input = "pub enum Foo {\n #[sqlx(rename = \"a\")]\n A,\n}";
874 let result = add_blank_lines_between_items(input);
875 assert!(!result.contains("{\n\n"));
877 }
878
879 #[test]
880 fn test_no_change_without_rename() {
881 let input = "pub enum Foo {\n A,\n B,\n}";
882 let result = add_blank_lines_between_items(input);
883 assert_eq!(result, input);
884 }
885
886 #[test]
887 fn test_no_change_for_struct() {
888 let input = "pub struct Foo {\n pub a: i32,\n pub b: String,\n}";
889 let result = add_blank_lines_between_items(input);
890 assert_eq!(result, input);
891 }
892
893 fn schema_with_two_role_enums() -> SchemaInfo {
896 SchemaInfo {
897 enums: vec![
898 crate::introspect::EnumInfo {
899 schema_name: "auth".into(),
900 name: "role".into(),
901 variants: vec!["admin".into(), "user".into()],
902 default_variant: None,
903 },
904 crate::introspect::EnumInfo {
905 schema_name: "billing".into(),
906 name: "role".into(),
907 variants: vec!["payer".into(), "payee".into()],
908 default_variant: None,
909 },
910 ],
911 ..Default::default()
912 }
913 }
914
915 #[test]
916 fn rust_type_name_prefixes_schema_on_cross_schema_collision() {
917 let s = schema_with_two_role_enums();
918 assert_eq!(rust_type_name_for(&s, "auth", "role"), "AuthRole");
919 assert_eq!(rust_type_name_for(&s, "billing", "role"), "BillingRole");
920 }
921
922 #[test]
923 fn rust_type_name_keeps_bare_name_when_unique() {
924 let s = SchemaInfo {
925 enums: vec![crate::introspect::EnumInfo {
926 schema_name: "auth".into(),
927 name: "role".into(),
928 variants: vec!["admin".into()],
929 default_variant: None,
930 }],
931 ..Default::default()
932 };
933 assert_eq!(rust_type_name_for(&s, "auth", "role"), "Role");
934 }
935
936 #[test]
937 fn required_search_path_collects_non_default_schemas() {
938 let s = SchemaInfo {
939 enums: vec![
940 crate::introspect::EnumInfo {
941 schema_name: "auth".into(),
942 name: "role".into(),
943 variants: vec!["x".into()],
944 default_variant: None,
945 },
946 crate::introspect::EnumInfo {
947 schema_name: "public".into(),
948 name: "status".into(),
949 variants: vec!["y".into()],
950 default_variant: None,
951 },
952 ],
953 composite_types: vec![crate::introspect::CompositeTypeInfo {
954 schema_name: "billing".into(),
955 name: "addr".into(),
956 fields: vec![],
957 }],
958 domains: vec![crate::introspect::DomainInfo {
959 schema_name: "auth".into(),
960 name: "email".into(),
961 base_type: "text".into(),
962 }],
963 ..Default::default()
964 };
965 assert_eq!(required_pg_search_path(&s), vec!["auth", "billing"]);
967 }
968
969 #[test]
970 fn required_search_path_empty_when_only_default_schema() {
971 let s = SchemaInfo {
972 enums: vec![crate::introspect::EnumInfo {
973 schema_name: "public".into(),
974 name: "status".into(),
975 variants: vec!["y".into()],
976 default_variant: None,
977 }],
978 ..Default::default()
979 };
980 assert!(required_pg_search_path(&s).is_empty());
981 }
982
983 #[test]
984 fn rust_type_name_default_schema_keeps_bare_name_even_on_collision() {
985 let s = SchemaInfo {
986 enums: vec![
987 crate::introspect::EnumInfo {
988 schema_name: "public".into(),
989 name: "role".into(),
990 variants: vec!["a".into()],
991 default_variant: None,
992 },
993 crate::introspect::EnumInfo {
994 schema_name: "auth".into(),
995 name: "role".into(),
996 variants: vec!["b".into()],
997 default_variant: None,
998 },
999 ],
1000 ..Default::default()
1001 };
1002 assert_eq!(rust_type_name_for(&s, "public", "role"), "Role");
1004 assert_eq!(rust_type_name_for(&s, "auth", "role"), "AuthRole");
1005 }
1006
1007 #[test]
1010 fn test_filter_single_file_strips_super_types() {
1011 let mut imports = BTreeSet::new();
1012 imports.insert("use super::types::Foo;".to_string());
1013 imports.insert("use chrono::NaiveDateTime;".to_string());
1014 let result = filter_imports(&imports, true);
1015 assert!(!result.contains("use super::types::Foo;"));
1016 assert!(result.contains("use chrono::NaiveDateTime;"));
1017 }
1018
1019 #[test]
1020 fn test_filter_single_file_keeps_other_imports() {
1021 let mut imports = BTreeSet::new();
1022 imports.insert("use chrono::NaiveDateTime;".to_string());
1023 let result = filter_imports(&imports, true);
1024 assert!(result.contains("use chrono::NaiveDateTime;"));
1025 }
1026
1027 #[test]
1028 fn test_filter_multi_file_keeps_all() {
1029 let mut imports = BTreeSet::new();
1030 imports.insert("use super::types::Foo;".to_string());
1031 imports.insert("use chrono::NaiveDateTime;".to_string());
1032 let result = filter_imports(&imports, false);
1033 assert_eq!(result.len(), 2);
1034 }
1035
1036 #[test]
1037 fn test_filter_empty_set() {
1038 let imports = BTreeSet::new();
1039 let result = filter_imports(&imports, true);
1040 assert!(result.is_empty());
1041 }
1042
1043 fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
1046 TableInfo {
1047 schema_name: "public".to_string(),
1048 name: name.to_string(),
1049 columns,
1050 }
1051 }
1052
1053 fn make_col(name: &str, udt_name: &str) -> ColumnInfo {
1054 ColumnInfo {
1055 name: name.to_string(),
1056 data_type: udt_name.to_string(),
1057 udt_name: udt_name.to_string(),
1058 is_nullable: false,
1059 is_primary_key: false,
1060 ordinal_position: 0,
1061 schema_name: "public".to_string(),
1062 udt_schema: None,
1063 column_default: None,
1064 }
1065 }
1066
1067 #[test]
1068 fn test_generate_empty_schema() {
1069 let schema = SchemaInfo::default();
1070 let files = generate(
1071 &schema,
1072 DatabaseKind::Postgres,
1073 &[],
1074 &HashMap::new(),
1075 false,
1076 TimeCrate::Chrono,
1077 )
1078 .unwrap();
1079 assert!(files.is_empty());
1080 }
1081
1082 #[test]
1083 fn test_generate_one_table() {
1084 let schema = SchemaInfo {
1085 tables: vec![make_table("users", vec![make_col("id", "int4")])],
1086 ..Default::default()
1087 };
1088 let files = generate(
1089 &schema,
1090 DatabaseKind::Postgres,
1091 &[],
1092 &HashMap::new(),
1093 false,
1094 TimeCrate::Chrono,
1095 )
1096 .unwrap();
1097 assert_eq!(files.len(), 1);
1098 assert_eq!(files[0].filename, "users.rs");
1099 }
1100
1101 #[test]
1102 fn test_generate_two_tables() {
1103 let schema = SchemaInfo {
1104 tables: vec![
1105 make_table("users", vec![make_col("id", "int4")]),
1106 make_table("posts", vec![make_col("id", "int4")]),
1107 ],
1108 ..Default::default()
1109 };
1110 let files = generate(
1111 &schema,
1112 DatabaseKind::Postgres,
1113 &[],
1114 &HashMap::new(),
1115 false,
1116 TimeCrate::Chrono,
1117 )
1118 .unwrap();
1119 assert_eq!(files.len(), 2);
1120 }
1121
1122 #[test]
1123 fn test_generate_enum_creates_types_file() {
1124 let schema = SchemaInfo {
1125 enums: vec![EnumInfo {
1126 schema_name: "public".to_string(),
1127 name: "status".to_string(),
1128 variants: vec!["active".to_string(), "inactive".to_string()],
1129 default_variant: None,
1130 }],
1131 ..Default::default()
1132 };
1133 let files = generate(
1134 &schema,
1135 DatabaseKind::Postgres,
1136 &[],
1137 &HashMap::new(),
1138 false,
1139 TimeCrate::Chrono,
1140 )
1141 .unwrap();
1142 assert_eq!(files.len(), 1);
1143 assert_eq!(files[0].filename, "types.rs");
1144 }
1145
1146 #[test]
1147 fn test_generate_enums_composites_domains_single_types_file() {
1148 let schema = SchemaInfo {
1149 enums: vec![EnumInfo {
1150 schema_name: "public".to_string(),
1151 name: "status".to_string(),
1152 variants: vec!["active".to_string()],
1153 default_variant: None,
1154 }],
1155 composite_types: vec![CompositeTypeInfo {
1156 schema_name: "public".to_string(),
1157 name: "address".to_string(),
1158 fields: vec![make_col("street", "text")],
1159 }],
1160 domains: vec![DomainInfo {
1161 schema_name: "public".to_string(),
1162 name: "email".to_string(),
1163 base_type: "text".to_string(),
1164 }],
1165 ..Default::default()
1166 };
1167 let files = generate(
1168 &schema,
1169 DatabaseKind::Postgres,
1170 &[],
1171 &HashMap::new(),
1172 false,
1173 TimeCrate::Chrono,
1174 )
1175 .unwrap();
1176 let types_files: Vec<_> = files.iter().filter(|f| f.filename == "types.rs").collect();
1178 assert_eq!(types_files.len(), 1);
1179 }
1180
1181 #[test]
1182 fn test_generate_tables_and_enums() {
1183 let schema = SchemaInfo {
1184 tables: vec![make_table("users", vec![make_col("id", "int4")])],
1185 enums: vec![EnumInfo {
1186 schema_name: "public".to_string(),
1187 name: "status".to_string(),
1188 variants: vec!["active".to_string()],
1189 default_variant: None,
1190 }],
1191 ..Default::default()
1192 };
1193 let files = generate(
1194 &schema,
1195 DatabaseKind::Postgres,
1196 &[],
1197 &HashMap::new(),
1198 false,
1199 TimeCrate::Chrono,
1200 )
1201 .unwrap();
1202 assert_eq!(files.len(), 2); }
1204
1205 #[test]
1206 fn test_generate_filename_normalized() {
1207 let schema = SchemaInfo {
1208 tables: vec![make_table("user__data", vec![make_col("id", "int4")])],
1209 ..Default::default()
1210 };
1211 let files = generate(
1212 &schema,
1213 DatabaseKind::Postgres,
1214 &[],
1215 &HashMap::new(),
1216 false,
1217 TimeCrate::Chrono,
1218 )
1219 .unwrap();
1220 assert_eq!(files[0].filename, "user_data.rs");
1221 }
1222
1223 #[test]
1224 fn test_generate_no_origin_for_tables() {
1225 let schema = SchemaInfo {
1226 tables: vec![make_table("users", vec![make_col("id", "int4")])],
1227 ..Default::default()
1228 };
1229 let files = generate(
1230 &schema,
1231 DatabaseKind::Postgres,
1232 &[],
1233 &HashMap::new(),
1234 false,
1235 TimeCrate::Chrono,
1236 )
1237 .unwrap();
1238 assert_eq!(files[0].origin, None);
1239 }
1240
1241 #[test]
1242 fn test_generate_types_no_origin() {
1243 let schema = SchemaInfo {
1244 enums: vec![EnumInfo {
1245 schema_name: "public".to_string(),
1246 name: "status".to_string(),
1247 variants: vec!["active".to_string()],
1248 default_variant: None,
1249 }],
1250 ..Default::default()
1251 };
1252 let files = generate(
1253 &schema,
1254 DatabaseKind::Postgres,
1255 &[],
1256 &HashMap::new(),
1257 false,
1258 TimeCrate::Chrono,
1259 )
1260 .unwrap();
1261 assert_eq!(files[0].origin, None);
1262 }
1263
1264 #[test]
1265 fn test_generate_single_file_filters_super_types_imports() {
1266 let schema = SchemaInfo {
1267 tables: vec![make_table("users", vec![make_col("id", "int4")])],
1268 enums: vec![EnumInfo {
1269 schema_name: "public".to_string(),
1270 name: "status".to_string(),
1271 variants: vec!["active".to_string()],
1272 default_variant: None,
1273 }],
1274 ..Default::default()
1275 };
1276 let files = generate(
1277 &schema,
1278 DatabaseKind::Postgres,
1279 &[],
1280 &HashMap::new(),
1281 true,
1282 TimeCrate::Chrono,
1283 )
1284 .unwrap();
1285 let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
1287 assert!(!struct_file.code.contains("super::types::"));
1288 }
1289
1290 #[test]
1291 fn test_generate_multi_file_keeps_super_types_imports() {
1292 let schema = SchemaInfo {
1294 tables: vec![make_table("users", vec![make_col("status", "status")])],
1295 enums: vec![EnumInfo {
1296 schema_name: "public".to_string(),
1297 name: "status".to_string(),
1298 variants: vec!["active".to_string()],
1299 default_variant: None,
1300 }],
1301 ..Default::default()
1302 };
1303 let files = generate(
1304 &schema,
1305 DatabaseKind::Postgres,
1306 &[],
1307 &HashMap::new(),
1308 false,
1309 TimeCrate::Chrono,
1310 )
1311 .unwrap();
1312 let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
1313 assert!(struct_file.code.contains("super::types::"));
1314 }
1315
1316 #[test]
1317 fn test_generate_extra_derives_in_struct() {
1318 let schema = SchemaInfo {
1319 tables: vec![make_table("users", vec![make_col("id", "int4")])],
1320 ..Default::default()
1321 };
1322 let derives = vec!["Serialize".to_string()];
1323 let files = generate(
1324 &schema,
1325 DatabaseKind::Postgres,
1326 &derives,
1327 &HashMap::new(),
1328 false,
1329 TimeCrate::Chrono,
1330 )
1331 .unwrap();
1332 assert!(files[0].code.contains("Serialize"));
1333 }
1334
1335 #[test]
1336 fn test_generate_extra_derives_in_enum() {
1337 let schema = SchemaInfo {
1338 enums: vec![EnumInfo {
1339 schema_name: "public".to_string(),
1340 name: "status".to_string(),
1341 variants: vec!["active".to_string()],
1342 default_variant: None,
1343 }],
1344 ..Default::default()
1345 };
1346 let derives = vec!["Serialize".to_string()];
1347 let files = generate(
1348 &schema,
1349 DatabaseKind::Postgres,
1350 &derives,
1351 &HashMap::new(),
1352 false,
1353 TimeCrate::Chrono,
1354 )
1355 .unwrap();
1356 assert!(files[0].code.contains("Serialize"));
1357 }
1358
1359 #[test]
1360 fn test_generate_type_overrides_in_struct() {
1361 let mut overrides = HashMap::new();
1362 overrides.insert("jsonb".to_string(), "MyJson".to_string());
1363 let schema = SchemaInfo {
1364 tables: vec![make_table("users", vec![make_col("data", "jsonb")])],
1365 ..Default::default()
1366 };
1367 let files = generate(
1368 &schema,
1369 DatabaseKind::Postgres,
1370 &[],
1371 &overrides,
1372 false,
1373 TimeCrate::Chrono,
1374 )
1375 .unwrap();
1376 assert!(files[0].code.contains("MyJson"));
1377 }
1378
1379 #[test]
1380 fn test_generate_valid_rust_syntax() {
1381 let schema = SchemaInfo {
1382 tables: vec![make_table(
1383 "users",
1384 vec![make_col("id", "int4"), make_col("name", "text")],
1385 )],
1386 enums: vec![EnumInfo {
1387 schema_name: "public".to_string(),
1388 name: "status".to_string(),
1389 variants: vec!["active".to_string(), "inactive".to_string()],
1390 default_variant: None,
1391 }],
1392 ..Default::default()
1393 };
1394 let files = generate(
1395 &schema,
1396 DatabaseKind::Postgres,
1397 &[],
1398 &HashMap::new(),
1399 false,
1400 TimeCrate::Chrono,
1401 )
1402 .unwrap();
1403 for f in &files {
1404 let parse_result = syn::parse_file(&f.code);
1406 assert!(
1407 parse_result.is_ok(),
1408 "Failed to parse {}: {:?}",
1409 f.filename,
1410 parse_result.err()
1411 );
1412 }
1413 }
1414
1415 fn make_view(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
1418 TableInfo {
1419 schema_name: "public".to_string(),
1420 name: name.to_string(),
1421 columns,
1422 }
1423 }
1424
1425 #[test]
1426 fn test_generate_one_view() {
1427 let schema = SchemaInfo {
1428 views: vec![make_view("active_users", vec![make_col("id", "int4")])],
1429 ..Default::default()
1430 };
1431 let files = generate(
1432 &schema,
1433 DatabaseKind::Postgres,
1434 &[],
1435 &HashMap::new(),
1436 false,
1437 TimeCrate::Chrono,
1438 )
1439 .unwrap();
1440 assert_eq!(files.len(), 1);
1441 assert_eq!(files[0].filename, "active_users.rs");
1442 }
1443
1444 #[test]
1445 fn test_generate_no_origin_for_views() {
1446 let schema = SchemaInfo {
1447 views: vec![make_view("active_users", vec![make_col("id", "int4")])],
1448 ..Default::default()
1449 };
1450 let files = generate(
1451 &schema,
1452 DatabaseKind::Postgres,
1453 &[],
1454 &HashMap::new(),
1455 false,
1456 TimeCrate::Chrono,
1457 )
1458 .unwrap();
1459 assert_eq!(files[0].origin, None);
1460 }
1461
1462 #[test]
1463 fn test_generate_tables_and_views() {
1464 let schema = SchemaInfo {
1465 tables: vec![make_table("users", vec![make_col("id", "int4")])],
1466 views: vec![make_view("active_users", vec![make_col("id", "int4")])],
1467 ..Default::default()
1468 };
1469 let files = generate(
1470 &schema,
1471 DatabaseKind::Postgres,
1472 &[],
1473 &HashMap::new(),
1474 false,
1475 TimeCrate::Chrono,
1476 )
1477 .unwrap();
1478 assert_eq!(files.len(), 2);
1479 }
1480
1481 #[test]
1482 fn test_generate_view_valid_rust() {
1483 let schema = SchemaInfo {
1484 views: vec![make_view(
1485 "active_users",
1486 vec![make_col("id", "int4"), make_col("name", "text")],
1487 )],
1488 ..Default::default()
1489 };
1490 let files = generate(
1491 &schema,
1492 DatabaseKind::Postgres,
1493 &[],
1494 &HashMap::new(),
1495 false,
1496 TimeCrate::Chrono,
1497 )
1498 .unwrap();
1499 let parse_result = syn::parse_file(&files[0].code);
1500 assert!(
1501 parse_result.is_ok(),
1502 "Failed to parse: {:?}",
1503 parse_result.err()
1504 );
1505 }
1506
1507 #[test]
1508 fn test_generate_view_nullable_column() {
1509 let schema = SchemaInfo {
1510 views: vec![make_view(
1511 "v",
1512 vec![ColumnInfo {
1513 name: "email".to_string(),
1514 data_type: "text".to_string(),
1515 udt_name: "text".to_string(),
1516 is_nullable: true,
1517 is_primary_key: false,
1518 ordinal_position: 0,
1519 schema_name: "public".to_string(),
1520 udt_schema: None,
1521 column_default: None,
1522 }],
1523 )],
1524 ..Default::default()
1525 };
1526 let files = generate(
1527 &schema,
1528 DatabaseKind::Postgres,
1529 &[],
1530 &HashMap::new(),
1531 false,
1532 TimeCrate::Chrono,
1533 )
1534 .unwrap();
1535 assert!(files[0].code.contains("Option<String>"));
1536 }
1537
1538 #[test]
1539 fn test_generate_collision_both_prefixed() {
1540 let schema = SchemaInfo {
1541 tables: vec![
1542 make_table("users", vec![make_col("id", "int4")]),
1543 TableInfo {
1544 schema_name: "billing".to_string(),
1545 name: "users".to_string(),
1546 columns: vec![make_col("id", "int4")],
1547 },
1548 ],
1549 ..Default::default()
1550 };
1551 let files = generate(
1552 &schema,
1553 DatabaseKind::Postgres,
1554 &[],
1555 &HashMap::new(),
1556 false,
1557 TimeCrate::Chrono,
1558 )
1559 .unwrap();
1560 let filenames: Vec<_> = files.iter().map(|f| f.filename.as_str()).collect();
1561 assert!(filenames.contains(&"users.rs"));
1562 assert!(filenames.contains(&"billing_users.rs"));
1563 }
1564
1565 #[test]
1566 fn test_generate_no_collision_no_prefix() {
1567 let schema = SchemaInfo {
1568 tables: vec![
1569 make_table("users", vec![make_col("id", "int4")]),
1570 TableInfo {
1571 schema_name: "billing".to_string(),
1572 name: "invoices".to_string(),
1573 columns: vec![make_col("id", "int4")],
1574 },
1575 ],
1576 ..Default::default()
1577 };
1578 let files = generate(
1579 &schema,
1580 DatabaseKind::Postgres,
1581 &[],
1582 &HashMap::new(),
1583 false,
1584 TimeCrate::Chrono,
1585 )
1586 .unwrap();
1587 let filenames: Vec<_> = files.iter().map(|f| f.filename.as_str()).collect();
1588 assert!(filenames.contains(&"users.rs"));
1589 assert!(filenames.contains(&"invoices.rs"));
1590 }
1591
1592 #[test]
1593 fn test_generate_single_schema_no_prefix() {
1594 let schema = SchemaInfo {
1595 tables: vec![
1596 make_table("users", vec![make_col("id", "int4")]),
1597 make_table("posts", vec![make_col("id", "int4")]),
1598 ],
1599 ..Default::default()
1600 };
1601 let files = generate(
1602 &schema,
1603 DatabaseKind::Postgres,
1604 &[],
1605 &HashMap::new(),
1606 false,
1607 TimeCrate::Chrono,
1608 )
1609 .unwrap();
1610 assert_eq!(files[0].filename, "users.rs");
1611 assert_eq!(files[1].filename, "posts.rs");
1612 }
1613
1614 #[test]
1615 fn test_generate_view_single_file_mode() {
1616 let schema = SchemaInfo {
1617 tables: vec![make_table("users", vec![make_col("id", "int4")])],
1618 views: vec![make_view("active_users", vec![make_col("id", "int4")])],
1619 ..Default::default()
1620 };
1621 let files = generate(
1622 &schema,
1623 DatabaseKind::Postgres,
1624 &[],
1625 &HashMap::new(),
1626 true,
1627 TimeCrate::Chrono,
1628 )
1629 .unwrap();
1630 assert_eq!(files.len(), 2);
1631 }
1632
1633 #[test]
1636 fn test_parse_pg_enum_default_simple() {
1637 assert_eq!(
1638 parse_pg_enum_default("'idle'::task_status"),
1639 Some("idle".to_string())
1640 );
1641 }
1642
1643 #[test]
1644 fn test_parse_pg_enum_default_schema_qualified() {
1645 assert_eq!(
1646 parse_pg_enum_default("'active'::public.task_status"),
1647 Some("active".to_string())
1648 );
1649 }
1650
1651 #[test]
1652 fn test_parse_pg_enum_default_not_enum() {
1653 assert_eq!(parse_pg_enum_default("nextval('users_id_seq')"), None);
1655 }
1656
1657 #[test]
1658 fn test_parse_pg_enum_default_no_cast() {
1659 assert_eq!(parse_pg_enum_default("'hello'"), None);
1660 }
1661
1662 #[test]
1663 fn test_parse_pg_enum_default_empty() {
1664 assert_eq!(parse_pg_enum_default(""), None);
1665 }
1666
1667 #[test]
1670 fn test_extract_enum_defaults_from_column() {
1671 let schema = SchemaInfo {
1672 tables: vec![TableInfo {
1673 schema_name: "public".to_string(),
1674 name: "tasks".to_string(),
1675 columns: vec![ColumnInfo {
1676 name: "status".to_string(),
1677 data_type: "USER-DEFINED".to_string(),
1678 udt_name: "task_status".to_string(),
1679 is_nullable: false,
1680 is_primary_key: false,
1681 ordinal_position: 0,
1682 schema_name: "public".to_string(),
1683 udt_schema: None,
1684 column_default: Some("'idle'::task_status".to_string()),
1685 }],
1686 }],
1687 enums: vec![EnumInfo {
1688 schema_name: "public".to_string(),
1689 name: "task_status".to_string(),
1690 variants: vec!["idle".to_string(), "running".to_string()],
1691 default_variant: None,
1692 }],
1693 ..Default::default()
1694 };
1695 let defaults = extract_enum_defaults(&schema);
1696 assert_eq!(defaults.get("task_status"), Some(&"idle".to_string()));
1697 }
1698
1699 #[test]
1700 fn test_extract_enum_defaults_no_default() {
1701 let schema = SchemaInfo {
1702 tables: vec![TableInfo {
1703 schema_name: "public".to_string(),
1704 name: "tasks".to_string(),
1705 columns: vec![ColumnInfo {
1706 name: "status".to_string(),
1707 data_type: "USER-DEFINED".to_string(),
1708 udt_name: "task_status".to_string(),
1709 is_nullable: false,
1710 is_primary_key: false,
1711 ordinal_position: 0,
1712 schema_name: "public".to_string(),
1713 udt_schema: None,
1714 column_default: None,
1715 }],
1716 }],
1717 enums: vec![EnumInfo {
1718 schema_name: "public".to_string(),
1719 name: "task_status".to_string(),
1720 variants: vec!["idle".to_string()],
1721 default_variant: None,
1722 }],
1723 ..Default::default()
1724 };
1725 let defaults = extract_enum_defaults(&schema);
1726 assert!(defaults.is_empty());
1727 }
1728
1729 #[test]
1730 fn test_extract_enum_defaults_non_enum_column_ignored() {
1731 let schema = SchemaInfo {
1732 tables: vec![TableInfo {
1733 schema_name: "public".to_string(),
1734 name: "users".to_string(),
1735 columns: vec![ColumnInfo {
1736 name: "name".to_string(),
1737 data_type: "character varying".to_string(),
1738 udt_name: "varchar".to_string(),
1739 is_nullable: false,
1740 is_primary_key: false,
1741 ordinal_position: 0,
1742 schema_name: "public".to_string(),
1743 udt_schema: None,
1744 column_default: Some("'hello'::character varying".to_string()),
1745 }],
1746 }],
1747 enums: vec![],
1748 ..Default::default()
1749 };
1750 let defaults = extract_enum_defaults(&schema);
1751 assert!(defaults.is_empty());
1752 }
1753
1754 #[test]
1755 fn test_generate_enum_with_default() {
1756 let schema = SchemaInfo {
1757 tables: vec![TableInfo {
1758 schema_name: "public".to_string(),
1759 name: "tasks".to_string(),
1760 columns: vec![ColumnInfo {
1761 name: "status".to_string(),
1762 data_type: "USER-DEFINED".to_string(),
1763 udt_name: "task_status".to_string(),
1764 is_nullable: false,
1765 is_primary_key: false,
1766 ordinal_position: 0,
1767 schema_name: "public".to_string(),
1768 udt_schema: None,
1769 column_default: Some("'idle'::task_status".to_string()),
1770 }],
1771 }],
1772 enums: vec![EnumInfo {
1773 schema_name: "public".to_string(),
1774 name: "task_status".to_string(),
1775 variants: vec!["idle".to_string(), "running".to_string()],
1776 default_variant: None,
1777 }],
1778 ..Default::default()
1779 };
1780 let files = generate(
1781 &schema,
1782 DatabaseKind::Postgres,
1783 &[],
1784 &HashMap::new(),
1785 false,
1786 TimeCrate::Chrono,
1787 )
1788 .unwrap();
1789 let types_file = files.iter().find(|f| f.filename == "types.rs").unwrap();
1790 assert!(types_file.code.contains("impl Default for TaskStatus"));
1791 assert!(types_file.code.contains("Self::Idle"));
1792 }
1793}