Skip to main content

sqlx_gen/codegen/
mod.rs

1pub mod composite_gen;
2pub mod crud_gen;
3pub mod domain_gen;
4pub mod entity_parser;
5pub mod enum_gen;
6pub mod identifiers;
7pub mod struct_gen;
8
9use std::collections::{BTreeSet, HashMap};
10use std::path::Path;
11
12use proc_macro2::TokenStream;
13
14use crate::cli::{DatabaseKind, TimeCrate};
15use crate::introspect::SchemaInfo;
16
/// Rust reserved keywords that cannot be used as identifiers.
///
/// Contains the strict and reserved keywords, including `gen`, which is
/// reserved starting with the 2024 edition. Weak/contextual keywords such as
/// `union` are valid identifiers and are not listed.
const RUST_KEYWORDS: &[&str] = &[
    "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum", "extern",
    "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub",
    "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true", "type",
    "unsafe", "use", "where", "while", "yield", "abstract", "become", "box", "do", "final",
    "macro", "override", "priv", "try", "typeof", "unsized", "virtual", "gen",
];

/// Returns true if the given name is a Rust reserved keyword.
///
/// The comparison is case-sensitive (`Type` is not a keyword, `type` is).
pub fn is_rust_keyword(name: &str) -> bool {
    RUST_KEYWORDS.contains(&name)
}
30
/// Returns the `use` lines required by well-known extra derives.
///
/// Currently only the serde derives are recognised: when `Serialize` and/or
/// `Deserialize` appear in `extra_derives`, a single grouped
/// `use serde::{...};` line is returned (in `Serialize, Deserialize` order).
/// Any other derive contributes nothing, so the result may be empty.
pub fn imports_for_derives(extra_derives: &[String]) -> Vec<String> {
    let requested = |wanted: &str| extra_derives.iter().any(|d| d.as_str() == wanted);

    let serde_traits: Vec<&str> = ["Serialize", "Deserialize"]
        .iter()
        .copied()
        .filter(|name| requested(name))
        .collect();

    if serde_traits.is_empty() {
        return Vec::new();
    }
    vec![format!("use serde::{{{}}};", serde_traits.join(", "))]
}
47
/// Normalize a table name for use as a Rust module/filename by collapsing
/// every run of consecutive underscores into a single `_`.
///
/// Leading and trailing underscores are kept (collapsed, not removed), e.g.
/// `"__a___b_"` becomes `"_a_b_"`.
pub fn normalize_module_name(name: &str) -> String {
    let mut out = String::with_capacity(name.len());
    for ch in name.chars() {
        // The output ends with '_' exactly when the last emitted char was an
        // underscore, so this skips the 2nd, 3rd, ... underscore of a run.
        if ch == '_' && out.ends_with('_') {
            continue;
        }
        out.push(ch);
    }
    out
}
66
/// Well-known default schema names (`public`, `main`, `dbo`); objects living
/// in one of these never get a schema prefix in generated names.
const DEFAULT_SCHEMAS: &[&str] = &["public", "main", "dbo"];

/// Returns true if `schema` is exactly one of [`DEFAULT_SCHEMAS`]
/// (case-sensitive comparison).
pub fn is_default_schema(schema: &str) -> bool {
    DEFAULT_SCHEMAS.iter().any(|&known| known == schema)
}
74
75/// Compute the Rust identifier for an enum / composite / domain.
76///
77/// When the same SQL name is declared in more than one schema (e.g. both
78/// `auth.role` and `billing.role` exist), the non-default-schema variants get
79/// a `SchemaName` PascalCase prefix to avoid Rust-level identifier collisions.
80/// Otherwise the bare PascalCase of the SQL name is used.
81pub fn rust_type_name_for(schema_info: &SchemaInfo, schema: &str, name: &str) -> String {
82    use heck::ToUpperCamelCase;
83    if type_name_has_cross_schema_collision(schema_info, name) && !is_default_schema(schema) {
84        format!(
85            "{}{}",
86            schema.to_upper_camel_case(),
87            name.to_upper_camel_case()
88        )
89    } else {
90        name.to_upper_camel_case()
91    }
92}
93
94/// Compute the schemas that must appear in PostgreSQL's `search_path` for
95/// the generated code to resolve every emitted unqualified `type_name`.
96///
97/// Returns the deduplicated, sorted list of non-default schemas hosting
98/// enums/composites/domains in `schema_info`. The caller can feed this into
99/// the pool's connect-hook, e.g.:
100///
101/// ```ignore
102/// let schemas = sqlx_gen::codegen::required_pg_search_path(&info).join(", ");
103/// sqlx::query(&format!("SET search_path TO public, {}", schemas))
104///     .execute(&pool).await?;
105/// ```
106pub fn required_pg_search_path(schema_info: &SchemaInfo) -> Vec<String> {
107    let mut schemas: std::collections::BTreeSet<String> = std::collections::BTreeSet::new();
108    for e in &schema_info.enums {
109        if !is_default_schema(&e.schema_name) {
110            schemas.insert(e.schema_name.clone());
111        }
112    }
113    for c in &schema_info.composite_types {
114        if !is_default_schema(&c.schema_name) {
115            schemas.insert(c.schema_name.clone());
116        }
117    }
118    for d in &schema_info.domains {
119        if !is_default_schema(&d.schema_name) {
120            schemas.insert(d.schema_name.clone());
121        }
122    }
123    schemas.into_iter().collect()
124}
125
126/// True when the SQL `name` is declared by enums / composites / domains living
127/// in more than one schema.
128pub fn type_name_has_cross_schema_collision(schema_info: &SchemaInfo, name: &str) -> bool {
129    let mut schemas: std::collections::BTreeSet<&str> = std::collections::BTreeSet::new();
130    schemas.extend(
131        schema_info
132            .enums
133            .iter()
134            .filter(|e| e.name == name)
135            .map(|e| e.schema_name.as_str()),
136    );
137    schemas.extend(
138        schema_info
139            .composite_types
140            .iter()
141            .filter(|c| c.name == name)
142            .map(|c| c.schema_name.as_str()),
143    );
144    schemas.extend(
145        schema_info
146            .domains
147            .iter()
148            .filter(|d| d.name == name)
149            .map(|d| d.schema_name.as_str()),
150    );
151    schemas.len() > 1
152}
153
154/// Build a module name, prefixing with schema only when the name collides
155/// (same table name exists in multiple schemas).
156pub fn build_module_name(schema_name: &str, table_name: &str, name_collides: bool) -> String {
157    if name_collides && !is_default_schema(schema_name) {
158        normalize_module_name(&format!("{}_{}", schema_name, table_name))
159    } else {
160        normalize_module_name(table_name)
161    }
162}
163
164/// Find table/view names that appear in more than one schema.
165fn find_colliding_names(schema_info: &SchemaInfo) -> BTreeSet<&str> {
166    let mut seen: HashMap<&str, BTreeSet<&str>> = HashMap::new();
167    for t in &schema_info.tables {
168        seen.entry(t.name.as_str())
169            .or_default()
170            .insert(t.schema_name.as_str());
171    }
172    for v in &schema_info.views {
173        seen.entry(v.name.as_str())
174            .or_default()
175            .insert(v.schema_name.as_str());
176    }
177    seen.into_iter()
178        .filter(|(_, schemas)| schemas.len() > 1)
179        .map(|(name, _)| name)
180        .collect()
181}
182
/// A generated code file: its file name, an optional origin note, and the
/// formatted source. Any `use` lines the code needs are already included at
/// the top of `code`; there is no separate imports field.
#[derive(Debug, Clone)]
pub struct GeneratedFile {
    /// Output file name, e.g. `users.rs` for a table/view or `types.rs` for
    /// the shared enums/composites/domains file.
    pub filename: String,
    /// Optional origin comment (e.g. "Table: schema.name").
    /// The `generate*` functions in this module always leave it `None`.
    pub origin: Option<String>,
    /// Formatted Rust source, including its import lines.
    pub code: String,
}
191
192/// Generate all code for a given schema.
193pub fn generate(
194    schema_info: &SchemaInfo,
195    db_kind: DatabaseKind,
196    extra_derives: &[String],
197    type_overrides: &HashMap<String, String>,
198    single_file: bool,
199    time_crate: TimeCrate,
200) -> crate::error::Result<Vec<GeneratedFile>> {
201    generate_with_domain_style(
202        schema_info,
203        db_kind,
204        extra_derives,
205        type_overrides,
206        single_file,
207        time_crate,
208        crate::cli::DomainStyle::Alias,
209    )
210}
211
212/// Same as [`generate`] but lets the caller pick how Postgres domains are
213/// rendered (alias vs newtype).
214pub fn generate_with_domain_style(
215    schema_info: &SchemaInfo,
216    db_kind: DatabaseKind,
217    extra_derives: &[String],
218    type_overrides: &HashMap<String, String>,
219    single_file: bool,
220    time_crate: TimeCrate,
221    domain_style: crate::cli::DomainStyle,
222) -> crate::error::Result<Vec<GeneratedFile>> {
223    let mut files = Vec::new();
224
225    // Detect table/view names that appear in multiple schemas (collisions)
226    let colliding_names = find_colliding_names(schema_info);
227
228    // Generate struct files for each table
229    for table in &schema_info.tables {
230        let (tokens, imports) = struct_gen::generate_struct(
231            table,
232            db_kind,
233            schema_info,
234            extra_derives,
235            type_overrides,
236            false,
237            time_crate,
238        );
239        let imports = filter_imports(&imports, single_file);
240        let code = format_tokens_with_imports(&tokens, &imports)?;
241        let module_name = build_module_name(
242            &table.schema_name,
243            &table.name,
244            colliding_names.contains(table.name.as_str()),
245        );
246        files.push(GeneratedFile {
247            filename: format!("{}.rs", module_name),
248            origin: None,
249            code,
250        });
251    }
252
253    // Generate struct files for each view
254    for view in &schema_info.views {
255        let (tokens, imports) = struct_gen::generate_struct(
256            view,
257            db_kind,
258            schema_info,
259            extra_derives,
260            type_overrides,
261            true,
262            time_crate,
263        );
264        let imports = filter_imports(&imports, single_file);
265        let code = format_tokens_with_imports(&tokens, &imports)?;
266        let module_name = build_module_name(
267            &view.schema_name,
268            &view.name,
269            colliding_names.contains(view.name.as_str()),
270        );
271        files.push(GeneratedFile {
272            filename: format!("{}.rs", module_name),
273            origin: None,
274            code,
275        });
276    }
277
278    // Generate types file (enums, composites, domains)
279    // Each item is formatted individually so we can insert blank lines between them.
280    let mut types_blocks: Vec<String> = Vec::new();
281    let mut types_imports = BTreeSet::new();
282
283    // Enrich enums with default variants extracted from column defaults
284    let enum_defaults = extract_enum_defaults(schema_info);
285    for enum_info in &schema_info.enums {
286        enum_gen::check_variant_collisions(enum_info)?;
287        let mut enriched = enum_info.clone();
288        if enriched.default_variant.is_none() {
289            if let Some(default) = enum_defaults.get(&enum_info.name) {
290                enriched.default_variant = Some(default.clone());
291            }
292        }
293        let (tokens, imports) =
294            enum_gen::generate_enum_with_schema(&enriched, db_kind, extra_derives, schema_info);
295        types_blocks.push(format_tokens(&tokens)?);
296        types_imports.extend(imports);
297    }
298
299    for composite in &schema_info.composite_types {
300        let (tokens, imports) = composite_gen::generate_composite(
301            composite,
302            db_kind,
303            schema_info,
304            extra_derives,
305            type_overrides,
306            time_crate,
307        );
308        types_blocks.push(format_tokens(&tokens)?);
309        types_imports.extend(imports);
310    }
311
312    for domain in &schema_info.domains {
313        let (tokens, imports) = domain_gen::generate_domain_with_style(
314            domain,
315            db_kind,
316            schema_info,
317            type_overrides,
318            time_crate,
319            domain_style,
320        );
321        types_blocks.push(format_tokens(&tokens)?);
322        types_imports.extend(imports);
323    }
324
325    if !types_blocks.is_empty() {
326        let import_lines: String = types_imports.iter().map(|i| format!("{}\n", i)).collect();
327        let body = types_blocks.join("\n");
328        let code = if import_lines.is_empty() {
329            body
330        } else {
331            format!("{}\n\n{}", import_lines.trim_end(), body)
332        };
333        files.push(GeneratedFile {
334            filename: "types.rs".to_string(),
335            origin: None,
336            code,
337        });
338    }
339
340    Ok(files)
341}
342
343/// Extract default variant values for enums by scanning column defaults across all tables and views.
344/// PostgreSQL column defaults look like `'idle'::task_status` or `'active'::public.task_status`.
345fn extract_enum_defaults(schema_info: &SchemaInfo) -> HashMap<String, String> {
346    let mut defaults: HashMap<String, String> = HashMap::new();
347
348    let all_columns = schema_info
349        .tables
350        .iter()
351        .chain(schema_info.views.iter())
352        .flat_map(|t| t.columns.iter());
353
354    for col in all_columns {
355        let default_expr = match &col.column_default {
356            Some(d) => d,
357            None => continue,
358        };
359
360        // Strip leading underscore for array types to get the base enum name
361        let base_udt = col.udt_name.strip_prefix('_').unwrap_or(&col.udt_name);
362
363        // Check if this column references a known enum
364        let enum_match = schema_info.enums.iter().find(|e| e.name == base_udt);
365        if enum_match.is_none() {
366            continue;
367        }
368
369        // Parse PG default: 'variant'::type_name
370        if let Some(variant) = parse_pg_enum_default(default_expr) {
371            defaults.entry(base_udt.to_string()).or_insert(variant);
372        }
373    }
374
375    defaults
376}
377
/// Parse a PostgreSQL column default expression to extract the enum variant.
/// Handles formats like `'idle'::task_status`, `'idle'::public.task_status`.
///
/// Inside the quoted literal, a doubled single quote is the SQL escape for
/// one quote (`'it''s'::mood` yields `it's`). Returns `None` when the
/// expression is not a quoted literal immediately followed by a `::` cast, or
/// when the literal is unterminated.
fn parse_pg_enum_default(default_expr: &str) -> Option<String> {
    let body = default_expr.trim().strip_prefix('\'')?;
    let mut value = String::new();
    let mut chars = body.char_indices().peekable();

    while let Some((idx, ch)) = chars.next() {
        if ch != '\'' {
            value.push(ch);
            continue;
        }
        // `''` is an escaped quote, not the end of the literal.
        if let Some(&(_, '\'')) = chars.peek() {
            chars.next();
            value.push('\'');
            continue;
        }
        // Closing quote: the literal must be followed by a cast.
        let rest = &body[idx + 1..];
        return if rest.starts_with("::") { Some(value) } else { None };
    }

    None
}
391
/// Return the imports to emit for one generated file.
///
/// In single-file mode every type lives in the same file, so any import that
/// mentions `super::types::` is dropped. Otherwise the set is returned as-is
/// (as an owned copy).
fn filter_imports(imports: &BTreeSet<String>, single_file: bool) -> BTreeSet<String> {
    if !single_file {
        return imports.clone();
    }
    let mut kept = BTreeSet::new();
    for line in imports {
        if line.contains("super::types::") {
            continue;
        }
        kept.insert(line.clone());
    }
    kept
}
404
/// Detect `tab_spaces` from `rustfmt.toml` or `.rustfmt.toml` by walking up
/// from `start_dir`. Returns 4 (rustfmt default) if no config is found.
///
/// If `start_dir` is a file, the search starts at its parent directory. The
/// first config file found wins. If it has no parseable `tab_spaces = N` line,
/// 4 is returned without looking further up. Trailing TOML comments
/// (`tab_spaces = 2 # two-space indent`) are ignored.
pub fn detect_tab_spaces(start_dir: &Path) -> usize {
    const RUSTFMT_DEFAULT_TAB_SPACES: usize = 4;

    let mut dir = if start_dir.is_file() {
        start_dir.parent().unwrap_or(start_dir)
    } else {
        start_dir
    };
    loop {
        for name in ["rustfmt.toml", ".rustfmt.toml"] {
            if let Ok(content) = std::fs::read_to_string(dir.join(name)) {
                // Config found: use its value, or the rustfmt default if absent.
                return parse_tab_spaces(&content).unwrap_or(RUSTFMT_DEFAULT_TAB_SPACES);
            }
        }
        match dir.parent() {
            Some(parent) => dir = parent,
            None => return RUSTFMT_DEFAULT_TAB_SPACES,
        }
    }
}

/// Extract the value of the first valid top-level `tab_spaces = N` line.
fn parse_tab_spaces(content: &str) -> Option<usize> {
    content.lines().find_map(|line| {
        // Drop a trailing TOML comment before parsing the value.
        let line = line.split('#').next().unwrap_or("").trim();
        let rest = line.strip_prefix("tab_spaces")?;
        let value = rest.trim_start().strip_prefix('=')?;
        value.trim().parse::<usize>().ok()
    })
}
436
437/// Parse and format a TokenStream via prettyplease, then post-process spacing.
438/// `tab_spaces` controls how many spaces per indentation level for SQL inside raw strings.
439pub(crate) fn parse_and_format(tokens: &TokenStream) -> crate::error::Result<String> {
440    parse_and_format_with_tab_spaces(tokens, 4)
441}
442
443pub(crate) fn parse_and_format_with_tab_spaces(
444    tokens: &TokenStream,
445    tab_spaces: usize,
446) -> crate::error::Result<String> {
447    let file = syn::parse2::<syn::File>(tokens.clone()).map_err(|e| {
448        crate::error::Error::Config(format!(
449            "Internal sqlx-gen bug: failed to parse generated code: {}. \
450             Raw tokens:\n  {}\n\
451             Please report this with the input schema.",
452            e, tokens
453        ))
454    })?;
455    let raw = prettyplease::unparse(&file);
456    let raw = indent_multiline_raw_strings(&raw, tab_spaces);
457    Ok(add_blank_lines_between_items(&raw))
458}
459
460/// Format a single TokenStream block (no imports).
461pub(crate) fn format_tokens(tokens: &TokenStream) -> crate::error::Result<String> {
462    parse_and_format(tokens)
463}
464
465pub fn format_tokens_with_imports(
466    tokens: &TokenStream,
467    imports: &BTreeSet<String>,
468) -> crate::error::Result<String> {
469    format_tokens_with_imports_and_tab_spaces(tokens, imports, 4)
470}
471
472pub fn format_tokens_with_imports_and_tab_spaces(
473    tokens: &TokenStream,
474    imports: &BTreeSet<String>,
475    tab_spaces: usize,
476) -> crate::error::Result<String> {
477    let formatted = parse_and_format_with_tab_spaces(tokens, tab_spaces)?;
478
479    let used_imports: Vec<&String> = imports
480        .iter()
481        .filter(|imp| is_import_used(imp, &formatted))
482        .collect();
483
484    if used_imports.is_empty() {
485        Ok(formatted)
486    } else {
487        let import_lines: String = used_imports.iter().map(|i| format!("{}\n", i)).collect();
488        Ok(format!("{}\n\n{}", import_lines.trim_end(), formatted))
489    }
490}
491
/// Check if an import is actually used in the generated code.
///
/// The names an import binds are extracted and searched for in `code` as
/// whole identifiers, so `NaiveDate` does not count as used merely because
/// `NaiveDateTime` appears:
/// - `use foo::bar::Baz;`   → `Baz`
/// - `use foo::Bar as Baz;` → `Baz`
/// - `use foo::{A, B};`     → `A` or `B` (nested paths use their last segment)
/// - `use foo::bar::*;`, `use foo::Trait as _;` → always kept
///
/// The check is purely textual: an occurrence inside a string literal or a
/// comment still counts as a use.
fn is_import_used(import: &str, code: &str) -> bool {
    let trimmed = import.trim().trim_end_matches(';');
    let path = trimmed.strip_prefix("use ").unwrap_or(trimmed);

    if path.ends_with("::*") {
        return true;
    }

    // Grouped imports: use foo::{A, B, C};
    if let (Some(start), Some(end)) = (path.find('{'), path.rfind('}')) {
        if start < end {
            return path[start + 1..end]
                .split(',')
                .map(|item| item.trim_matches(|c: char| c == '{' || c == '}' || c.is_whitespace()))
                .filter(|item| !item.is_empty())
                .any(|item| import_binding_is_used(item, code));
        }
    }

    // Simple import: use foo::Bar;
    import_binding_is_used(path, code)
}

/// True when the name bound by a single import item (`a::B`, `B as C`, `*`)
/// occurs in `code`; glob, `_` and empty bindings are treated as used.
fn import_binding_is_used(item: &str, code: &str) -> bool {
    let item = item.trim();
    let bound = match item.rsplit_once(" as ") {
        Some((_, alias)) => alias.trim(),
        None => item.rsplit("::").next().unwrap_or(item).trim(),
    };
    bound.is_empty() || bound == "*" || bound == "_" || contains_identifier(code, bound)
}

/// True when `ident` occurs in `code` not directly preceded or followed by an
/// identifier character (alphanumeric or `_`).
fn contains_identifier(code: &str, ident: &str) -> bool {
    let is_ident_char = |c: char| c.is_alphanumeric() || c == '_';
    code.match_indices(ident).any(|(pos, _)| {
        let before = code[..pos].chars().next_back();
        let after = code[pos + ident.len()..].chars().next();
        !before.map_or(false, is_ident_char) && !after.map_or(false, is_ident_char)
    })
}
524
/// Indent the content of multi-line raw string literals (`r#"..."#`) so SQL
/// reads naturally in generated code. All SQL raw strings live inside `impl`
/// methods, so content is indented at a fixed 2-level depth and relative
/// indentation between SQL lines is preserved (e.g. SET items indented under
/// the SET keyword).
///
/// Only raw strings whose closing `"#` starts its own line are re-indented.
/// A raw string that closes mid-line, or is never closed, is emitted exactly
/// as received, so the code after it is never lost or re-indented as SQL.
fn indent_multiline_raw_strings(code: &str, tab_spaces: usize) -> String {
    // Raw string content is NOT reformatted by external formatters, so we must
    // bake the right indentation at generation time.
    // The closing "# sits at impl(4) + one `tab_spaces` level, and SQL content
    // gets one more `tab_spaces` level beyond that.
    let close_indent = 4 + tab_spaces; // impl(4) + fn_arg(tab)
    let sql_indent = 4 + 2 * tab_spaces; // impl(4) + fn_arg(tab) + sql(tab)

    let lines: Vec<&str> = code.lines().collect();
    let mut result: Vec<String> = Vec::with_capacity(lines.len());
    let mut inside_raw = false;
    let mut raw_lines: Vec<&str> = Vec::new();

    for &line in &lines {
        if !inside_raw {
            if let Some(pos) = line.find("r#\"") {
                let after = &line[pos + 3..];
                // A `"#` later on the same line means a single-line raw string.
                if !after.contains("\"#") {
                    inside_raw = true;
                    raw_lines.clear();
                }
            }
            result.push(line.to_string());
        } else if line.trim_start().starts_with("\"#") {
            // Find minimum indentation among non-empty content lines
            let min_indent = raw_lines
                .iter()
                .filter(|l| !l.trim().is_empty())
                .map(|l| l.len() - l.trim_start().len())
                .min()
                .unwrap_or(0);
            for raw_line in &raw_lines {
                let trimmed = raw_line.trim();
                if trimmed.is_empty() {
                    result.push(String::new());
                } else {
                    let original_indent = raw_line.len() - raw_line.trim_start().len();
                    let relative = original_indent.saturating_sub(min_indent);
                    result.push(format!(
                        "{}{}{}",
                        " ".repeat(sql_indent),
                        " ".repeat(relative),
                        trimmed
                    ));
                }
            }
            // Closing "# at method body level
            let trimmed = line.trim();
            result.push(format!("{}{}", " ".repeat(close_indent), trimmed));
            inside_raw = false;
        } else if line.contains("\"#") {
            // The raw string closes mid-line (a single-`#` raw string cannot
            // contain `"#` otherwise). This layout is not re-indented, so emit
            // the collected lines and this one untouched.
            result.extend(raw_lines.drain(..).map(|l| l.to_string()));
            result.push(line.to_string());
            inside_raw = false;
        } else {
            raw_lines.push(line);
        }
    }

    // Unterminated raw string: keep the collected lines instead of dropping them.
    if inside_raw {
        result.extend(raw_lines.iter().map(|l| l.to_string()));
    }

    result.join("\n")
}
587
/// Insert single blank lines into formatted code to separate logical blocks.
///
/// A blank line goes before a line when any of the rules in
/// [`needs_blank_line_before`] match. At most one blank line is inserted per
/// position, even when several rules match at once (e.g. a `let … sqlx::`
/// line right after `let x = f().await?;`). The result is joined with `\n`
/// and has no trailing newline.
fn add_blank_lines_between_items(code: &str) -> String {
    let lines: Vec<&str> = code.lines().collect();
    let mut result: Vec<&str> = Vec::with_capacity(lines.len());

    for (i, &line) in lines.iter().enumerate() {
        if i > 0 && needs_blank_line_before(lines[i - 1].trim(), line.trim()) {
            result.push("");
        }
        result.push(line);
    }

    result.join("\n")
}

/// Decide whether a blank line belongs between `prev` and `current` (both
/// already trimmed).
fn needs_blank_line_before(prev: &str, current: &str) -> bool {
    // Separate `#[sqlx(rename` attributes that follow a variant line (ending
    // with `,`), but not the first variant of an enum.
    let separates_renamed_variant = current.starts_with("#[sqlx(rename") && prev.ends_with(',');

    // Top-level items (pub struct, impl, #[derive) and methods inside impl
    // blocks, when preceded by a closing brace `}`.
    let starts_item_after_block = prev == "}"
        && (current.starts_with("pub struct")
            || current.starts_with("impl ")
            || current.starts_with("#[derive")
            || current.starts_with("pub async fn")
            || current.starts_with("pub fn"));

    // A new logical block inside a method: `let` or `Ok(` after an awaited
    // statement or an `.unwrap_or(…);` statement.
    let prev_is_await_end = prev.ends_with(".await?;")
        || prev.ends_with(".await?")
        || (prev.ends_with(';') && prev.contains(".unwrap_or("));
    let follows_await =
        prev_is_await_end && (current.starts_with("let ") || current.starts_with("Ok("));

    // Separate a sqlx query `let` from preceding simple `let` assignments.
    let sqlx_let_after_plain_let = current.starts_with("let ")
        && current.contains("sqlx::")
        && prev.starts_with("let ")
        && !prev.contains("sqlx::");

    separates_renamed_variant || starts_item_after_block || follows_await || sqlx_let_after_plain_let
}
645
646#[cfg(test)]
647mod tests {
648    use super::*;
649    use crate::introspect::{
650        ColumnInfo, CompositeTypeInfo, DomainInfo, EnumInfo, SchemaInfo, TableInfo,
651    };
652    use std::collections::HashMap;
653
654    // ========== is_rust_keyword ==========
655
656    #[test]
657    fn test_keyword_type() {
658        assert!(is_rust_keyword("type"));
659    }
660
661    #[test]
662    fn test_keyword_fn() {
663        assert!(is_rust_keyword("fn"));
664    }
665
666    #[test]
667    fn test_keyword_let() {
668        assert!(is_rust_keyword("let"));
669    }
670
671    #[test]
672    fn test_keyword_match() {
673        assert!(is_rust_keyword("match"));
674    }
675
676    #[test]
677    fn test_keyword_async() {
678        assert!(is_rust_keyword("async"));
679    }
680
681    #[test]
682    fn test_keyword_await() {
683        assert!(is_rust_keyword("await"));
684    }
685
686    #[test]
687    fn test_keyword_yield() {
688        assert!(is_rust_keyword("yield"));
689    }
690
691    #[test]
692    fn test_keyword_abstract() {
693        assert!(is_rust_keyword("abstract"));
694    }
695
696    #[test]
697    fn test_keyword_try() {
698        assert!(is_rust_keyword("try"));
699    }
700
701    #[test]
702    fn test_not_keyword_name() {
703        assert!(!is_rust_keyword("name"));
704    }
705
706    #[test]
707    fn test_not_keyword_id() {
708        assert!(!is_rust_keyword("id"));
709    }
710
711    #[test]
712    fn test_not_keyword_uppercase_type() {
713        assert!(!is_rust_keyword("Type"));
714    }
715
716    // ========== normalize_module_name ==========
717
718    #[test]
719    fn test_normalize_no_underscores() {
720        assert_eq!(normalize_module_name("users"), "users");
721    }
722
723    #[test]
724    fn test_normalize_single_underscore() {
725        assert_eq!(normalize_module_name("user_roles"), "user_roles");
726    }
727
728    #[test]
729    fn test_normalize_double_underscore() {
730        assert_eq!(normalize_module_name("user__roles"), "user_roles");
731    }
732
733    #[test]
734    fn test_normalize_triple_underscore() {
735        assert_eq!(normalize_module_name("a___b"), "a_b");
736    }
737
738    #[test]
739    fn test_normalize_leading_underscore() {
740        assert_eq!(normalize_module_name("_private"), "_private");
741    }
742
743    #[test]
744    fn test_normalize_trailing_underscore() {
745        assert_eq!(normalize_module_name("name_"), "name_");
746    }
747
748    #[test]
749    fn test_normalize_double_leading() {
750        assert_eq!(normalize_module_name("__double_leading"), "_double_leading");
751    }
752
753    #[test]
754    fn test_normalize_multiple_groups() {
755        assert_eq!(normalize_module_name("a__b__c"), "a_b_c");
756    }
757
758    // ========== build_module_name ==========
759
760    #[test]
761    fn test_build_no_collision_no_prefix() {
762        assert_eq!(build_module_name("public", "users", false), "users");
763    }
764
765    #[test]
766    fn test_build_no_collision_non_default_no_prefix() {
767        assert_eq!(build_module_name("billing", "invoices", false), "invoices");
768    }
769
770    #[test]
771    fn test_build_collision_prefixed() {
772        assert_eq!(build_module_name("billing", "users", true), "billing_users");
773    }
774
775    #[test]
776    fn test_build_collision_default_schema_no_prefix() {
777        assert_eq!(build_module_name("public", "users", true), "users");
778    }
779
780    #[test]
781    fn test_build_collision_normalizes_double_underscore() {
782        assert_eq!(
783            build_module_name("billing", "agent__connector", true),
784            "billing_agent_connector"
785        );
786    }
787
788    // ========== is_default_schema ==========
789
790    #[test]
791    fn test_default_schema_public() {
792        assert!(is_default_schema("public"));
793    }
794
795    #[test]
796    fn test_default_schema_main() {
797        assert!(is_default_schema("main"));
798    }
799
800    #[test]
801    fn test_non_default_schema() {
802        assert!(!is_default_schema("billing"));
803    }
804
805    // ========== imports_for_derives ==========
806
807    #[test]
808    fn test_imports_empty() {
809        let result = imports_for_derives(&[]);
810        assert!(result.is_empty());
811    }
812
813    #[test]
814    fn test_imports_serialize_only() {
815        let derives = vec!["Serialize".to_string()];
816        let result = imports_for_derives(&derives);
817        assert_eq!(result, vec!["use serde::{Serialize};"]);
818    }
819
820    #[test]
821    fn test_imports_deserialize_only() {
822        let derives = vec!["Deserialize".to_string()];
823        let result = imports_for_derives(&derives);
824        assert_eq!(result, vec!["use serde::{Deserialize};"]);
825    }
826
827    #[test]
828    fn test_imports_both_serde() {
829        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
830        let result = imports_for_derives(&derives);
831        assert_eq!(result, vec!["use serde::{Serialize, Deserialize};"]);
832    }
833
834    #[test]
835    fn test_imports_non_serde() {
836        let derives = vec!["Hash".to_string()];
837        let result = imports_for_derives(&derives);
838        assert!(result.is_empty());
839    }
840
841    #[test]
842    fn test_imports_non_serde_multiple() {
843        let derives = vec!["PartialEq".to_string(), "Eq".to_string()];
844        let result = imports_for_derives(&derives);
845        assert!(result.is_empty());
846    }
847
848    #[test]
849    fn test_imports_mixed_serde_and_others() {
850        let derives = vec![
851            "Serialize".to_string(),
852            "Hash".to_string(),
853            "Deserialize".to_string(),
854        ];
855        let result = imports_for_derives(&derives);
856        assert_eq!(result.len(), 1);
857        assert!(result[0].contains("Serialize"));
858        assert!(result[0].contains("Deserialize"));
859    }
860
861    // ========== add_blank_lines_between_items ==========
862
863    #[test]
864    fn test_blank_lines_between_renamed_variants() {
865        let input = "pub enum Foo {\n    #[sqlx(rename = \"a\")]\n    A,\n    #[sqlx(rename = \"b\")]\n    B,\n}";
866        let result = add_blank_lines_between_items(input);
867        assert!(result.contains("A,\n\n    #[sqlx(rename = \"b\")]"));
868    }
869
870    #[test]
871    fn test_no_blank_line_for_first_variant() {
872        let input = "pub enum Foo {\n    #[sqlx(rename = \"a\")]\n    A,\n}";
873        let result = add_blank_lines_between_items(input);
874        // No blank line before first #[sqlx(rename because previous line is `{`
875        assert!(!result.contains("{\n\n"));
876    }
877
878    #[test]
879    fn test_no_change_without_rename() {
880        let input = "pub enum Foo {\n    A,\n    B,\n}";
881        let result = add_blank_lines_between_items(input);
882        assert_eq!(result, input);
883    }
884
885    #[test]
886    fn test_no_change_for_struct() {
887        let input = "pub struct Foo {\n    pub a: i32,\n    pub b: String,\n}";
888        let result = add_blank_lines_between_items(input);
889        assert_eq!(result, input);
890    }
891
892    // ========== rust_type_name_for / cross-schema collisions ==========
893
894    fn schema_with_two_role_enums() -> SchemaInfo {
895        SchemaInfo {
896            enums: vec![
897                crate::introspect::EnumInfo {
898                    schema_name: "auth".into(),
899                    name: "role".into(),
900                    variants: vec!["admin".into(), "user".into()],
901                    default_variant: None,
902                },
903                crate::introspect::EnumInfo {
904                    schema_name: "billing".into(),
905                    name: "role".into(),
906                    variants: vec!["payer".into(), "payee".into()],
907                    default_variant: None,
908                },
909            ],
910            ..Default::default()
911        }
912    }
913
914    #[test]
915    fn rust_type_name_prefixes_schema_on_cross_schema_collision() {
916        let s = schema_with_two_role_enums();
917        assert_eq!(rust_type_name_for(&s, "auth", "role"), "AuthRole");
918        assert_eq!(rust_type_name_for(&s, "billing", "role"), "BillingRole");
919    }
920
921    #[test]
922    fn rust_type_name_keeps_bare_name_when_unique() {
923        let s = SchemaInfo {
924            enums: vec![crate::introspect::EnumInfo {
925                schema_name: "auth".into(),
926                name: "role".into(),
927                variants: vec!["admin".into()],
928                default_variant: None,
929            }],
930            ..Default::default()
931        };
932        assert_eq!(rust_type_name_for(&s, "auth", "role"), "Role");
933    }
934
935    #[test]
936    fn required_search_path_collects_non_default_schemas() {
937        let s = SchemaInfo {
938            enums: vec![
939                crate::introspect::EnumInfo {
940                    schema_name: "auth".into(),
941                    name: "role".into(),
942                    variants: vec!["x".into()],
943                    default_variant: None,
944                },
945                crate::introspect::EnumInfo {
946                    schema_name: "public".into(),
947                    name: "status".into(),
948                    variants: vec!["y".into()],
949                    default_variant: None,
950                },
951            ],
952            composite_types: vec![crate::introspect::CompositeTypeInfo {
953                schema_name: "billing".into(),
954                name: "addr".into(),
955                fields: vec![],
956            }],
957            domains: vec![crate::introspect::DomainInfo {
958                schema_name: "auth".into(),
959                name: "email".into(),
960                base_type: "text".into(),
961            }],
962            ..Default::default()
963        };
964        // Sorted, deduplicated, public excluded.
965        assert_eq!(required_pg_search_path(&s), vec!["auth", "billing"]);
966    }
967
968    #[test]
969    fn required_search_path_empty_when_only_default_schema() {
970        let s = SchemaInfo {
971            enums: vec![crate::introspect::EnumInfo {
972                schema_name: "public".into(),
973                name: "status".into(),
974                variants: vec!["y".into()],
975                default_variant: None,
976            }],
977            ..Default::default()
978        };
979        assert!(required_pg_search_path(&s).is_empty());
980    }
981
982    #[test]
983    fn rust_type_name_default_schema_keeps_bare_name_even_on_collision() {
984        let s = SchemaInfo {
985            enums: vec![
986                crate::introspect::EnumInfo {
987                    schema_name: "public".into(),
988                    name: "role".into(),
989                    variants: vec!["a".into()],
990                    default_variant: None,
991                },
992                crate::introspect::EnumInfo {
993                    schema_name: "auth".into(),
994                    name: "role".into(),
995                    variants: vec!["b".into()],
996                    default_variant: None,
997                },
998            ],
999            ..Default::default()
1000        };
1001        // public stays "Role"; auth gets the schema prefix to break the tie.
1002        assert_eq!(rust_type_name_for(&s, "public", "role"), "Role");
1003        assert_eq!(rust_type_name_for(&s, "auth", "role"), "AuthRole");
1004    }
1005
1006    // ========== filter_imports ==========
1007
1008    #[test]
1009    fn test_filter_single_file_strips_super_types() {
1010        let mut imports = BTreeSet::new();
1011        imports.insert("use super::types::Foo;".to_string());
1012        imports.insert("use chrono::NaiveDateTime;".to_string());
1013        let result = filter_imports(&imports, true);
1014        assert!(!result.contains("use super::types::Foo;"));
1015        assert!(result.contains("use chrono::NaiveDateTime;"));
1016    }
1017
1018    #[test]
1019    fn test_filter_single_file_keeps_other_imports() {
1020        let mut imports = BTreeSet::new();
1021        imports.insert("use chrono::NaiveDateTime;".to_string());
1022        let result = filter_imports(&imports, true);
1023        assert!(result.contains("use chrono::NaiveDateTime;"));
1024    }
1025
1026    #[test]
1027    fn test_filter_multi_file_keeps_all() {
1028        let mut imports = BTreeSet::new();
1029        imports.insert("use super::types::Foo;".to_string());
1030        imports.insert("use chrono::NaiveDateTime;".to_string());
1031        let result = filter_imports(&imports, false);
1032        assert_eq!(result.len(), 2);
1033    }
1034
1035    #[test]
1036    fn test_filter_empty_set() {
1037        let imports = BTreeSet::new();
1038        let result = filter_imports(&imports, true);
1039        assert!(result.is_empty());
1040    }
1041
1042    // ========== generate() orchestrator ==========
1043
1044    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
1045        TableInfo {
1046            schema_name: "public".to_string(),
1047            name: name.to_string(),
1048            columns,
1049        }
1050    }
1051
1052    fn make_col(name: &str, udt_name: &str) -> ColumnInfo {
1053        ColumnInfo {
1054            name: name.to_string(),
1055            data_type: udt_name.to_string(),
1056            udt_name: udt_name.to_string(),
1057            is_nullable: false,
1058            is_primary_key: false,
1059            ordinal_position: 0,
1060            schema_name: "public".to_string(),
1061            udt_schema: None,
1062            column_default: None,
1063        }
1064    }
1065
1066    #[test]
1067    fn test_generate_empty_schema() {
1068        let schema = SchemaInfo::default();
1069        let files = generate(
1070            &schema,
1071            DatabaseKind::Postgres,
1072            &[],
1073            &HashMap::new(),
1074            false,
1075            TimeCrate::Chrono,
1076        )
1077        .unwrap();
1078        assert!(files.is_empty());
1079    }
1080
1081    #[test]
1082    fn test_generate_one_table() {
1083        let schema = SchemaInfo {
1084            tables: vec![make_table("users", vec![make_col("id", "int4")])],
1085            ..Default::default()
1086        };
1087        let files = generate(
1088            &schema,
1089            DatabaseKind::Postgres,
1090            &[],
1091            &HashMap::new(),
1092            false,
1093            TimeCrate::Chrono,
1094        )
1095        .unwrap();
1096        assert_eq!(files.len(), 1);
1097        assert_eq!(files[0].filename, "users.rs");
1098    }
1099
1100    #[test]
1101    fn test_generate_two_tables() {
1102        let schema = SchemaInfo {
1103            tables: vec![
1104                make_table("users", vec![make_col("id", "int4")]),
1105                make_table("posts", vec![make_col("id", "int4")]),
1106            ],
1107            ..Default::default()
1108        };
1109        let files = generate(
1110            &schema,
1111            DatabaseKind::Postgres,
1112            &[],
1113            &HashMap::new(),
1114            false,
1115            TimeCrate::Chrono,
1116        )
1117        .unwrap();
1118        assert_eq!(files.len(), 2);
1119    }
1120
1121    #[test]
1122    fn test_generate_enum_creates_types_file() {
1123        let schema = SchemaInfo {
1124            enums: vec![EnumInfo {
1125                schema_name: "public".to_string(),
1126                name: "status".to_string(),
1127                variants: vec!["active".to_string(), "inactive".to_string()],
1128                default_variant: None,
1129            }],
1130            ..Default::default()
1131        };
1132        let files = generate(
1133            &schema,
1134            DatabaseKind::Postgres,
1135            &[],
1136            &HashMap::new(),
1137            false,
1138            TimeCrate::Chrono,
1139        )
1140        .unwrap();
1141        assert_eq!(files.len(), 1);
1142        assert_eq!(files[0].filename, "types.rs");
1143    }
1144
1145    #[test]
1146    fn test_generate_enums_composites_domains_single_types_file() {
1147        let schema = SchemaInfo {
1148            enums: vec![EnumInfo {
1149                schema_name: "public".to_string(),
1150                name: "status".to_string(),
1151                variants: vec!["active".to_string()],
1152                default_variant: None,
1153            }],
1154            composite_types: vec![CompositeTypeInfo {
1155                schema_name: "public".to_string(),
1156                name: "address".to_string(),
1157                fields: vec![make_col("street", "text")],
1158            }],
1159            domains: vec![DomainInfo {
1160                schema_name: "public".to_string(),
1161                name: "email".to_string(),
1162                base_type: "text".to_string(),
1163            }],
1164            ..Default::default()
1165        };
1166        let files = generate(
1167            &schema,
1168            DatabaseKind::Postgres,
1169            &[],
1170            &HashMap::new(),
1171            false,
1172            TimeCrate::Chrono,
1173        )
1174        .unwrap();
1175        // Should produce exactly 1 types.rs
1176        let types_files: Vec<_> = files.iter().filter(|f| f.filename == "types.rs").collect();
1177        assert_eq!(types_files.len(), 1);
1178    }
1179
1180    #[test]
1181    fn test_generate_tables_and_enums() {
1182        let schema = SchemaInfo {
1183            tables: vec![make_table("users", vec![make_col("id", "int4")])],
1184            enums: vec![EnumInfo {
1185                schema_name: "public".to_string(),
1186                name: "status".to_string(),
1187                variants: vec!["active".to_string()],
1188                default_variant: None,
1189            }],
1190            ..Default::default()
1191        };
1192        let files = generate(
1193            &schema,
1194            DatabaseKind::Postgres,
1195            &[],
1196            &HashMap::new(),
1197            false,
1198            TimeCrate::Chrono,
1199        )
1200        .unwrap();
1201        assert_eq!(files.len(), 2); // users.rs + types.rs
1202    }
1203
1204    #[test]
1205    fn test_generate_filename_normalized() {
1206        let schema = SchemaInfo {
1207            tables: vec![make_table("user__data", vec![make_col("id", "int4")])],
1208            ..Default::default()
1209        };
1210        let files = generate(
1211            &schema,
1212            DatabaseKind::Postgres,
1213            &[],
1214            &HashMap::new(),
1215            false,
1216            TimeCrate::Chrono,
1217        )
1218        .unwrap();
1219        assert_eq!(files[0].filename, "user_data.rs");
1220    }
1221
1222    #[test]
1223    fn test_generate_no_origin_for_tables() {
1224        let schema = SchemaInfo {
1225            tables: vec![make_table("users", vec![make_col("id", "int4")])],
1226            ..Default::default()
1227        };
1228        let files = generate(
1229            &schema,
1230            DatabaseKind::Postgres,
1231            &[],
1232            &HashMap::new(),
1233            false,
1234            TimeCrate::Chrono,
1235        )
1236        .unwrap();
1237        assert_eq!(files[0].origin, None);
1238    }
1239
1240    #[test]
1241    fn test_generate_types_no_origin() {
1242        let schema = SchemaInfo {
1243            enums: vec![EnumInfo {
1244                schema_name: "public".to_string(),
1245                name: "status".to_string(),
1246                variants: vec!["active".to_string()],
1247                default_variant: None,
1248            }],
1249            ..Default::default()
1250        };
1251        let files = generate(
1252            &schema,
1253            DatabaseKind::Postgres,
1254            &[],
1255            &HashMap::new(),
1256            false,
1257            TimeCrate::Chrono,
1258        )
1259        .unwrap();
1260        assert_eq!(files[0].origin, None);
1261    }
1262
1263    #[test]
1264    fn test_generate_single_file_filters_super_types_imports() {
1265        let schema = SchemaInfo {
1266            tables: vec![make_table("users", vec![make_col("id", "int4")])],
1267            enums: vec![EnumInfo {
1268                schema_name: "public".to_string(),
1269                name: "status".to_string(),
1270                variants: vec!["active".to_string()],
1271                default_variant: None,
1272            }],
1273            ..Default::default()
1274        };
1275        let files = generate(
1276            &schema,
1277            DatabaseKind::Postgres,
1278            &[],
1279            &HashMap::new(),
1280            true,
1281            TimeCrate::Chrono,
1282        )
1283        .unwrap();
1284        // struct file should not have super::types:: imports
1285        let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
1286        assert!(!struct_file.code.contains("super::types::"));
1287    }
1288
1289    #[test]
1290    fn test_generate_multi_file_keeps_super_types_imports() {
1291        // Table with a column referencing an enum
1292        let schema = SchemaInfo {
1293            tables: vec![make_table("users", vec![make_col("status", "status")])],
1294            enums: vec![EnumInfo {
1295                schema_name: "public".to_string(),
1296                name: "status".to_string(),
1297                variants: vec!["active".to_string()],
1298                default_variant: None,
1299            }],
1300            ..Default::default()
1301        };
1302        let files = generate(
1303            &schema,
1304            DatabaseKind::Postgres,
1305            &[],
1306            &HashMap::new(),
1307            false,
1308            TimeCrate::Chrono,
1309        )
1310        .unwrap();
1311        let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
1312        assert!(struct_file.code.contains("super::types::"));
1313    }
1314
1315    #[test]
1316    fn test_generate_extra_derives_in_struct() {
1317        let schema = SchemaInfo {
1318            tables: vec![make_table("users", vec![make_col("id", "int4")])],
1319            ..Default::default()
1320        };
1321        let derives = vec!["Serialize".to_string()];
1322        let files = generate(
1323            &schema,
1324            DatabaseKind::Postgres,
1325            &derives,
1326            &HashMap::new(),
1327            false,
1328            TimeCrate::Chrono,
1329        )
1330        .unwrap();
1331        assert!(files[0].code.contains("Serialize"));
1332    }
1333
1334    #[test]
1335    fn test_generate_extra_derives_in_enum() {
1336        let schema = SchemaInfo {
1337            enums: vec![EnumInfo {
1338                schema_name: "public".to_string(),
1339                name: "status".to_string(),
1340                variants: vec!["active".to_string()],
1341                default_variant: None,
1342            }],
1343            ..Default::default()
1344        };
1345        let derives = vec!["Serialize".to_string()];
1346        let files = generate(
1347            &schema,
1348            DatabaseKind::Postgres,
1349            &derives,
1350            &HashMap::new(),
1351            false,
1352            TimeCrate::Chrono,
1353        )
1354        .unwrap();
1355        assert!(files[0].code.contains("Serialize"));
1356    }
1357
1358    #[test]
1359    fn test_generate_type_overrides_in_struct() {
1360        let mut overrides = HashMap::new();
1361        overrides.insert("jsonb".to_string(), "MyJson".to_string());
1362        let schema = SchemaInfo {
1363            tables: vec![make_table("users", vec![make_col("data", "jsonb")])],
1364            ..Default::default()
1365        };
1366        let files = generate(
1367            &schema,
1368            DatabaseKind::Postgres,
1369            &[],
1370            &overrides,
1371            false,
1372            TimeCrate::Chrono,
1373        )
1374        .unwrap();
1375        assert!(files[0].code.contains("MyJson"));
1376    }
1377
1378    #[test]
1379    fn test_generate_valid_rust_syntax() {
1380        let schema = SchemaInfo {
1381            tables: vec![make_table(
1382                "users",
1383                vec![make_col("id", "int4"), make_col("name", "text")],
1384            )],
1385            enums: vec![EnumInfo {
1386                schema_name: "public".to_string(),
1387                name: "status".to_string(),
1388                variants: vec!["active".to_string(), "inactive".to_string()],
1389                default_variant: None,
1390            }],
1391            ..Default::default()
1392        };
1393        let files = generate(
1394            &schema,
1395            DatabaseKind::Postgres,
1396            &[],
1397            &HashMap::new(),
1398            false,
1399            TimeCrate::Chrono,
1400        )
1401        .unwrap();
1402        for f in &files {
1403            // Should be parseable as valid Rust
1404            let parse_result = syn::parse_file(&f.code);
1405            assert!(
1406                parse_result.is_ok(),
1407                "Failed to parse {}: {:?}",
1408                f.filename,
1409                parse_result.err()
1410            );
1411        }
1412    }
1413
1414    // ========== generate() — views ==========
1415
1416    fn make_view(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
1417        TableInfo {
1418            schema_name: "public".to_string(),
1419            name: name.to_string(),
1420            columns,
1421        }
1422    }
1423
1424    #[test]
1425    fn test_generate_one_view() {
1426        let schema = SchemaInfo {
1427            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
1428            ..Default::default()
1429        };
1430        let files = generate(
1431            &schema,
1432            DatabaseKind::Postgres,
1433            &[],
1434            &HashMap::new(),
1435            false,
1436            TimeCrate::Chrono,
1437        )
1438        .unwrap();
1439        assert_eq!(files.len(), 1);
1440        assert_eq!(files[0].filename, "active_users.rs");
1441    }
1442
1443    #[test]
1444    fn test_generate_no_origin_for_views() {
1445        let schema = SchemaInfo {
1446            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
1447            ..Default::default()
1448        };
1449        let files = generate(
1450            &schema,
1451            DatabaseKind::Postgres,
1452            &[],
1453            &HashMap::new(),
1454            false,
1455            TimeCrate::Chrono,
1456        )
1457        .unwrap();
1458        assert_eq!(files[0].origin, None);
1459    }
1460
1461    #[test]
1462    fn test_generate_tables_and_views() {
1463        let schema = SchemaInfo {
1464            tables: vec![make_table("users", vec![make_col("id", "int4")])],
1465            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
1466            ..Default::default()
1467        };
1468        let files = generate(
1469            &schema,
1470            DatabaseKind::Postgres,
1471            &[],
1472            &HashMap::new(),
1473            false,
1474            TimeCrate::Chrono,
1475        )
1476        .unwrap();
1477        assert_eq!(files.len(), 2);
1478    }
1479
1480    #[test]
1481    fn test_generate_view_valid_rust() {
1482        let schema = SchemaInfo {
1483            views: vec![make_view(
1484                "active_users",
1485                vec![make_col("id", "int4"), make_col("name", "text")],
1486            )],
1487            ..Default::default()
1488        };
1489        let files = generate(
1490            &schema,
1491            DatabaseKind::Postgres,
1492            &[],
1493            &HashMap::new(),
1494            false,
1495            TimeCrate::Chrono,
1496        )
1497        .unwrap();
1498        let parse_result = syn::parse_file(&files[0].code);
1499        assert!(
1500            parse_result.is_ok(),
1501            "Failed to parse: {:?}",
1502            parse_result.err()
1503        );
1504    }
1505
1506    #[test]
1507    fn test_generate_view_nullable_column() {
1508        let schema = SchemaInfo {
1509            views: vec![make_view(
1510                "v",
1511                vec![ColumnInfo {
1512                    name: "email".to_string(),
1513                    data_type: "text".to_string(),
1514                    udt_name: "text".to_string(),
1515                    is_nullable: true,
1516                    is_primary_key: false,
1517                    ordinal_position: 0,
1518                    schema_name: "public".to_string(),
1519                    udt_schema: None,
1520                    column_default: None,
1521                }],
1522            )],
1523            ..Default::default()
1524        };
1525        let files = generate(
1526            &schema,
1527            DatabaseKind::Postgres,
1528            &[],
1529            &HashMap::new(),
1530            false,
1531            TimeCrate::Chrono,
1532        )
1533        .unwrap();
1534        assert!(files[0].code.contains("Option<String>"));
1535    }
1536
1537    #[test]
1538    fn test_generate_collision_both_prefixed() {
1539        let schema = SchemaInfo {
1540            tables: vec![
1541                make_table("users", vec![make_col("id", "int4")]),
1542                TableInfo {
1543                    schema_name: "billing".to_string(),
1544                    name: "users".to_string(),
1545                    columns: vec![make_col("id", "int4")],
1546                },
1547            ],
1548            ..Default::default()
1549        };
1550        let files = generate(
1551            &schema,
1552            DatabaseKind::Postgres,
1553            &[],
1554            &HashMap::new(),
1555            false,
1556            TimeCrate::Chrono,
1557        )
1558        .unwrap();
1559        let filenames: Vec<_> = files.iter().map(|f| f.filename.as_str()).collect();
1560        assert!(filenames.contains(&"users.rs"));
1561        assert!(filenames.contains(&"billing_users.rs"));
1562    }
1563
1564    #[test]
1565    fn test_generate_no_collision_no_prefix() {
1566        let schema = SchemaInfo {
1567            tables: vec![
1568                make_table("users", vec![make_col("id", "int4")]),
1569                TableInfo {
1570                    schema_name: "billing".to_string(),
1571                    name: "invoices".to_string(),
1572                    columns: vec![make_col("id", "int4")],
1573                },
1574            ],
1575            ..Default::default()
1576        };
1577        let files = generate(
1578            &schema,
1579            DatabaseKind::Postgres,
1580            &[],
1581            &HashMap::new(),
1582            false,
1583            TimeCrate::Chrono,
1584        )
1585        .unwrap();
1586        let filenames: Vec<_> = files.iter().map(|f| f.filename.as_str()).collect();
1587        assert!(filenames.contains(&"users.rs"));
1588        assert!(filenames.contains(&"invoices.rs"));
1589    }
1590
1591    #[test]
1592    fn test_generate_single_schema_no_prefix() {
1593        let schema = SchemaInfo {
1594            tables: vec![
1595                make_table("users", vec![make_col("id", "int4")]),
1596                make_table("posts", vec![make_col("id", "int4")]),
1597            ],
1598            ..Default::default()
1599        };
1600        let files = generate(
1601            &schema,
1602            DatabaseKind::Postgres,
1603            &[],
1604            &HashMap::new(),
1605            false,
1606            TimeCrate::Chrono,
1607        )
1608        .unwrap();
1609        assert_eq!(files[0].filename, "users.rs");
1610        assert_eq!(files[1].filename, "posts.rs");
1611    }
1612
1613    #[test]
1614    fn test_generate_view_single_file_mode() {
1615        let schema = SchemaInfo {
1616            tables: vec![make_table("users", vec![make_col("id", "int4")])],
1617            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
1618            ..Default::default()
1619        };
1620        let files = generate(
1621            &schema,
1622            DatabaseKind::Postgres,
1623            &[],
1624            &HashMap::new(),
1625            true,
1626            TimeCrate::Chrono,
1627        )
1628        .unwrap();
1629        assert_eq!(files.len(), 2);
1630    }
1631
1632    // ========== parse_pg_enum_default ==========
1633
1634    #[test]
1635    fn test_parse_pg_enum_default_simple() {
1636        assert_eq!(
1637            parse_pg_enum_default("'idle'::task_status"),
1638            Some("idle".to_string())
1639        );
1640    }
1641
1642    #[test]
1643    fn test_parse_pg_enum_default_schema_qualified() {
1644        assert_eq!(
1645            parse_pg_enum_default("'active'::public.task_status"),
1646            Some("active".to_string())
1647        );
1648    }
1649
1650    #[test]
1651    fn test_parse_pg_enum_default_not_enum() {
1652        // No single-quote pattern
1653        assert_eq!(parse_pg_enum_default("nextval('users_id_seq')"), None);
1654    }
1655
1656    #[test]
1657    fn test_parse_pg_enum_default_no_cast() {
1658        assert_eq!(parse_pg_enum_default("'hello'"), None);
1659    }
1660
1661    #[test]
1662    fn test_parse_pg_enum_default_empty() {
1663        assert_eq!(parse_pg_enum_default(""), None);
1664    }
1665
1666    // ========== extract_enum_defaults ==========
1667
1668    #[test]
1669    fn test_extract_enum_defaults_from_column() {
1670        let schema = SchemaInfo {
1671            tables: vec![TableInfo {
1672                schema_name: "public".to_string(),
1673                name: "tasks".to_string(),
1674                columns: vec![ColumnInfo {
1675                    name: "status".to_string(),
1676                    data_type: "USER-DEFINED".to_string(),
1677                    udt_name: "task_status".to_string(),
1678                    is_nullable: false,
1679                    is_primary_key: false,
1680                    ordinal_position: 0,
1681                    schema_name: "public".to_string(),
1682                    udt_schema: None,
1683                    column_default: Some("'idle'::task_status".to_string()),
1684                }],
1685            }],
1686            enums: vec![EnumInfo {
1687                schema_name: "public".to_string(),
1688                name: "task_status".to_string(),
1689                variants: vec!["idle".to_string(), "running".to_string()],
1690                default_variant: None,
1691            }],
1692            ..Default::default()
1693        };
1694        let defaults = extract_enum_defaults(&schema);
1695        assert_eq!(defaults.get("task_status"), Some(&"idle".to_string()));
1696    }
1697
1698    #[test]
1699    fn test_extract_enum_defaults_no_default() {
1700        let schema = SchemaInfo {
1701            tables: vec![TableInfo {
1702                schema_name: "public".to_string(),
1703                name: "tasks".to_string(),
1704                columns: vec![ColumnInfo {
1705                    name: "status".to_string(),
1706                    data_type: "USER-DEFINED".to_string(),
1707                    udt_name: "task_status".to_string(),
1708                    is_nullable: false,
1709                    is_primary_key: false,
1710                    ordinal_position: 0,
1711                    schema_name: "public".to_string(),
1712                    udt_schema: None,
1713                    column_default: None,
1714                }],
1715            }],
1716            enums: vec![EnumInfo {
1717                schema_name: "public".to_string(),
1718                name: "task_status".to_string(),
1719                variants: vec!["idle".to_string()],
1720                default_variant: None,
1721            }],
1722            ..Default::default()
1723        };
1724        let defaults = extract_enum_defaults(&schema);
1725        assert!(defaults.is_empty());
1726    }
1727
1728    #[test]
1729    fn test_extract_enum_defaults_non_enum_column_ignored() {
1730        let schema = SchemaInfo {
1731            tables: vec![TableInfo {
1732                schema_name: "public".to_string(),
1733                name: "users".to_string(),
1734                columns: vec![ColumnInfo {
1735                    name: "name".to_string(),
1736                    data_type: "character varying".to_string(),
1737                    udt_name: "varchar".to_string(),
1738                    is_nullable: false,
1739                    is_primary_key: false,
1740                    ordinal_position: 0,
1741                    schema_name: "public".to_string(),
1742                    udt_schema: None,
1743                    column_default: Some("'hello'::character varying".to_string()),
1744                }],
1745            }],
1746            enums: vec![],
1747            ..Default::default()
1748        };
1749        let defaults = extract_enum_defaults(&schema);
1750        assert!(defaults.is_empty());
1751    }
1752
1753    #[test]
1754    fn test_generate_enum_with_default() {
1755        let schema = SchemaInfo {
1756            tables: vec![TableInfo {
1757                schema_name: "public".to_string(),
1758                name: "tasks".to_string(),
1759                columns: vec![ColumnInfo {
1760                    name: "status".to_string(),
1761                    data_type: "USER-DEFINED".to_string(),
1762                    udt_name: "task_status".to_string(),
1763                    is_nullable: false,
1764                    is_primary_key: false,
1765                    ordinal_position: 0,
1766                    schema_name: "public".to_string(),
1767                    udt_schema: None,
1768                    column_default: Some("'idle'::task_status".to_string()),
1769                }],
1770            }],
1771            enums: vec![EnumInfo {
1772                schema_name: "public".to_string(),
1773                name: "task_status".to_string(),
1774                variants: vec!["idle".to_string(), "running".to_string()],
1775                default_variant: None,
1776            }],
1777            ..Default::default()
1778        };
1779        let files = generate(
1780            &schema,
1781            DatabaseKind::Postgres,
1782            &[],
1783            &HashMap::new(),
1784            false,
1785            TimeCrate::Chrono,
1786        )
1787        .unwrap();
1788        let types_file = files.iter().find(|f| f.filename == "types.rs").unwrap();
1789        assert!(types_file.code.contains("impl Default for TaskStatus"));
1790        assert!(types_file.code.contains("Self::Idle"));
1791    }
1792}