Skip to main content

sqlx_gen/codegen/
mod.rs

1pub mod composite_gen;
2pub mod crud_gen;
3pub mod domain_gen;
4pub mod entity_parser;
5pub mod enum_gen;
6pub mod identifiers;
7pub mod naming;
8pub mod struct_gen;
9
10use std::collections::{BTreeSet, HashMap};
11use std::path::Path;
12
13use proc_macro2::TokenStream;
14
15use crate::cli::{DatabaseKind, TimeCrate};
16use crate::introspect::SchemaInfo;
17
/// Words that `is_rust_keyword` treats as reserved and therefore unusable as
/// plain Rust identifiers.
///
/// NOTE(review): `gen` (reserved starting with the 2024 edition) is not in
/// this list — confirm whether generated code must compile under that edition.
const RUST_KEYWORDS: &[&str] = &[
    "as", "async", "await", "break", "const", "continue", "crate", "dyn",
    "else", "enum", "extern", "false", "fn", "for", "if", "impl",
    "in", "let", "loop", "match", "mod", "move", "mut", "pub",
    "ref", "return", "self", "Self", "static", "struct", "super", "trait",
    "true", "type", "unsafe", "use", "where", "while", "yield",
    "abstract", "become", "box", "do", "final", "macro", "override", "priv",
    "try", "typeof", "unsized", "virtual",
];

/// Reports whether `name` exactly matches (case-sensitively) one of the
/// entries in `RUST_KEYWORDS`.
pub fn is_rust_keyword(name: &str) -> bool {
    RUST_KEYWORDS.iter().any(|keyword| *keyword == name)
}
31
/// Builds the `use` lines required by well-known derives in `extra_derives`.
///
/// Only `Serialize` and `Deserialize` are recognised. When at least one of
/// them is requested, a single grouped `use serde::{...};` line is returned,
/// always listing `Serialize` before `Deserialize` regardless of input order.
/// Any other derive name contributes nothing.
pub fn imports_for_derives(extra_derives: &[String]) -> Vec<String> {
    let serde_items: Vec<&str> = ["Serialize", "Deserialize"]
        .iter()
        .copied()
        .filter(|wanted| extra_derives.iter().any(|d| d.as_str() == *wanted))
        .collect();

    if serde_items.is_empty() {
        return Vec::new();
    }
    vec![format!("use serde::{{{}}};", serde_items.join(", "))]
}
48
/// Normalizes a table name for use as a Rust module / file name by collapsing
/// every run of consecutive underscores into a single underscore.
///
/// Leading and trailing underscores are kept (a leading `__` becomes `_`);
/// all other characters pass through unchanged.
pub fn normalize_module_name(name: &str) -> String {
    let mut collapsed = String::with_capacity(name.len());
    for ch in name.chars() {
        // The previous emitted character tells us whether we are inside a run.
        if ch == '_' && collapsed.ends_with('_') {
            continue;
        }
        collapsed.push(ch);
    }
    collapsed
}
67
/// Schema names regarded as "default"; objects living in them are never given
/// a schema prefix in generated names.
const DEFAULT_SCHEMAS: &[&str] = &["public", "main", "dbo"];

/// Reports whether `schema` is exactly one of the well-known default schema
/// names (`public`, `main`, `dbo`). The comparison is case-sensitive.
pub fn is_default_schema(schema: &str) -> bool {
    DEFAULT_SCHEMAS.iter().any(|known| *known == schema)
}
75
76/// Compute the Rust identifier for an enum / composite / domain.
77///
78/// When the same SQL name is declared in more than one schema (e.g. both
79/// `auth.role` and `billing.role` exist), the non-default-schema variants get
80/// a `SchemaName` PascalCase prefix to avoid Rust-level identifier collisions.
81/// Otherwise the bare PascalCase of the SQL name is used.
82pub fn rust_type_name_for(schema_info: &SchemaInfo, schema: &str, name: &str) -> String {
83    use heck::ToUpperCamelCase;
84    if type_name_has_cross_schema_collision(schema_info, name) && !is_default_schema(schema) {
85        format!(
86            "{}{}",
87            schema.to_upper_camel_case(),
88            name.to_upper_camel_case()
89        )
90    } else {
91        name.to_upper_camel_case()
92    }
93}
94
95/// Compute the schemas that must appear in PostgreSQL's `search_path` for
96/// the generated code to resolve every emitted unqualified `type_name`.
97///
98/// Returns the deduplicated, sorted list of non-default schemas hosting
99/// enums/composites/domains in `schema_info`. The caller can feed this into
100/// the pool's connect-hook, e.g.:
101///
102/// ```ignore
103/// let schemas = sqlx_gen::codegen::required_pg_search_path(&info).join(", ");
104/// sqlx::query(&format!("SET search_path TO public, {}", schemas))
105///     .execute(&pool).await?;
106/// ```
107pub fn required_pg_search_path(schema_info: &SchemaInfo) -> Vec<String> {
108    let mut schemas: std::collections::BTreeSet<String> = std::collections::BTreeSet::new();
109    for e in &schema_info.enums {
110        if !is_default_schema(&e.schema_name) {
111            schemas.insert(e.schema_name.clone());
112        }
113    }
114    for c in &schema_info.composite_types {
115        if !is_default_schema(&c.schema_name) {
116            schemas.insert(c.schema_name.clone());
117        }
118    }
119    for d in &schema_info.domains {
120        if !is_default_schema(&d.schema_name) {
121            schemas.insert(d.schema_name.clone());
122        }
123    }
124    schemas.into_iter().collect()
125}
126
127/// True when the SQL `name` is declared by enums / composites / domains living
128/// in more than one schema.
129pub fn type_name_has_cross_schema_collision(schema_info: &SchemaInfo, name: &str) -> bool {
130    let mut schemas: std::collections::BTreeSet<&str> = std::collections::BTreeSet::new();
131    schemas.extend(
132        schema_info
133            .enums
134            .iter()
135            .filter(|e| e.name == name)
136            .map(|e| e.schema_name.as_str()),
137    );
138    schemas.extend(
139        schema_info
140            .composite_types
141            .iter()
142            .filter(|c| c.name == name)
143            .map(|c| c.schema_name.as_str()),
144    );
145    schemas.extend(
146        schema_info
147            .domains
148            .iter()
149            .filter(|d| d.name == name)
150            .map(|d| d.schema_name.as_str()),
151    );
152    schemas.len() > 1
153}
154
155/// Build a module name, prefixing with schema only when the name collides
156/// (same table name exists in multiple schemas).
157pub fn build_module_name(schema_name: &str, table_name: &str, name_collides: bool) -> String {
158    if name_collides && !is_default_schema(schema_name) {
159        normalize_module_name(&format!("{}_{}", schema_name, table_name))
160    } else {
161        normalize_module_name(table_name)
162    }
163}
164
165/// Find table/view names that appear in more than one schema.
166fn find_colliding_names(schema_info: &SchemaInfo) -> BTreeSet<&str> {
167    let mut seen: HashMap<&str, BTreeSet<&str>> = HashMap::new();
168    for t in &schema_info.tables {
169        seen.entry(t.name.as_str())
170            .or_default()
171            .insert(t.schema_name.as_str());
172    }
173    for v in &schema_info.views {
174        seen.entry(v.name.as_str())
175            .or_default()
176            .insert(v.schema_name.as_str());
177    }
178    seen.into_iter()
179        .filter(|(_, schemas)| schemas.len() > 1)
180        .map(|(name, _)| name)
181        .collect()
182}
183
/// A generated source file: its target file name plus the rendered code.
///
/// Any `use` lines the code needs are already part of `code`; there is no
/// separate imports field.
#[derive(Debug, Clone)]
pub struct GeneratedFile {
    /// File name of the generated file (e.g. `users.rs`, `types.rs`).
    pub filename: String,
    /// Optional origin comment (e.g. "Table: schema.name")
    pub origin: Option<String>,
    /// Rendered Rust source for this file.
    pub code: String,
}
192
193/// Generate all code for a given schema.
194pub fn generate(
195    schema_info: &SchemaInfo,
196    db_kind: DatabaseKind,
197    extra_derives: &[String],
198    type_overrides: &HashMap<String, String>,
199    single_file: bool,
200    time_crate: TimeCrate,
201) -> crate::error::Result<Vec<GeneratedFile>> {
202    generate_with_domain_style(
203        schema_info,
204        db_kind,
205        extra_derives,
206        type_overrides,
207        single_file,
208        time_crate,
209        crate::cli::DomainStyle::Alias,
210    )
211}
212
213/// Same as [`generate`] but lets the caller pick how Postgres domains are
214/// rendered (alias vs newtype).
215pub fn generate_with_domain_style(
216    schema_info: &SchemaInfo,
217    db_kind: DatabaseKind,
218    extra_derives: &[String],
219    type_overrides: &HashMap<String, String>,
220    single_file: bool,
221    time_crate: TimeCrate,
222    domain_style: crate::cli::DomainStyle,
223) -> crate::error::Result<Vec<GeneratedFile>> {
224    let mut files = Vec::new();
225
226    // Detect table/view names that appear in multiple schemas (collisions)
227    let colliding_names = find_colliding_names(schema_info);
228
229    // Generate struct files for each table
230    for table in &schema_info.tables {
231        let (tokens, imports) = struct_gen::generate_struct(
232            table,
233            db_kind,
234            schema_info,
235            extra_derives,
236            type_overrides,
237            false,
238            time_crate,
239        );
240        let imports = filter_imports(&imports, single_file);
241        let code = format_tokens_with_imports(&tokens, &imports)?;
242        let module_name = build_module_name(
243            &table.schema_name,
244            &table.name,
245            colliding_names.contains(table.name.as_str()),
246        );
247        files.push(GeneratedFile {
248            filename: format!("{}.rs", module_name),
249            origin: None,
250            code,
251        });
252    }
253
254    // Generate struct files for each view
255    for view in &schema_info.views {
256        let (tokens, imports) = struct_gen::generate_struct(
257            view,
258            db_kind,
259            schema_info,
260            extra_derives,
261            type_overrides,
262            true,
263            time_crate,
264        );
265        let imports = filter_imports(&imports, single_file);
266        let code = format_tokens_with_imports(&tokens, &imports)?;
267        let module_name = build_module_name(
268            &view.schema_name,
269            &view.name,
270            colliding_names.contains(view.name.as_str()),
271        );
272        files.push(GeneratedFile {
273            filename: format!("{}.rs", module_name),
274            origin: None,
275            code,
276        });
277    }
278
279    // Generate types file (enums, composites, domains)
280    // Each item is formatted individually so we can insert blank lines between them.
281    let mut types_blocks: Vec<String> = Vec::new();
282    let mut types_imports = BTreeSet::new();
283
284    // Enrich enums with default variants extracted from column defaults
285    let enum_defaults = extract_enum_defaults(schema_info);
286    for enum_info in &schema_info.enums {
287        enum_gen::check_variant_collisions(enum_info)?;
288        let mut enriched = enum_info.clone();
289        if enriched.default_variant.is_none() {
290            if let Some(default) = enum_defaults.get(&enum_info.name) {
291                enriched.default_variant = Some(default.clone());
292            }
293        }
294        let (tokens, imports) =
295            enum_gen::generate_enum_with_schema(&enriched, db_kind, extra_derives, schema_info);
296        types_blocks.push(format_tokens(&tokens)?);
297        types_imports.extend(imports);
298    }
299
300    for composite in &schema_info.composite_types {
301        let (tokens, imports) = composite_gen::generate_composite(
302            composite,
303            db_kind,
304            schema_info,
305            extra_derives,
306            type_overrides,
307            time_crate,
308        );
309        types_blocks.push(format_tokens(&tokens)?);
310        types_imports.extend(imports);
311    }
312
313    for domain in &schema_info.domains {
314        let (tokens, imports) = domain_gen::generate_domain_with_style(
315            domain,
316            db_kind,
317            schema_info,
318            type_overrides,
319            time_crate,
320            domain_style,
321        );
322        types_blocks.push(format_tokens(&tokens)?);
323        types_imports.extend(imports);
324    }
325
326    if !types_blocks.is_empty() {
327        let import_lines: String = types_imports.iter().map(|i| format!("{}\n", i)).collect();
328        let body = types_blocks.join("\n");
329        let code = if import_lines.is_empty() {
330            body
331        } else {
332            format!("{}\n\n{}", import_lines.trim_end(), body)
333        };
334        files.push(GeneratedFile {
335            filename: "types.rs".to_string(),
336            origin: None,
337            code,
338        });
339    }
340
341    Ok(files)
342}
343
344/// Extract default variant values for enums by scanning column defaults across all tables and views.
345/// PostgreSQL column defaults look like `'idle'::task_status` or `'active'::public.task_status`.
346fn extract_enum_defaults(schema_info: &SchemaInfo) -> HashMap<String, String> {
347    let mut defaults: HashMap<String, String> = HashMap::new();
348
349    let all_columns = schema_info
350        .tables
351        .iter()
352        .chain(schema_info.views.iter())
353        .flat_map(|t| t.columns.iter());
354
355    for col in all_columns {
356        let default_expr = match &col.column_default {
357            Some(d) => d,
358            None => continue,
359        };
360
361        // Strip leading underscore for array types to get the base enum name
362        let base_udt = col.udt_name.strip_prefix('_').unwrap_or(&col.udt_name);
363
364        // Check if this column references a known enum
365        let enum_match = schema_info.enums.iter().find(|e| e.name == base_udt);
366        if enum_match.is_none() {
367            continue;
368        }
369
370        // Parse PG default: 'variant'::type_name
371        if let Some(variant) = parse_pg_enum_default(default_expr) {
372            defaults.entry(base_udt.to_string()).or_insert(variant);
373        }
374    }
375
376    defaults
377}
378
/// Parses a PostgreSQL column default expression and extracts the quoted
/// literal that is cast to a type.
///
/// Handles formats like `'idle'::task_status` and `'idle'::public.task_status`.
/// A doubled single quote inside the literal (`''`, the SQL escape for `'`) is
/// decoded to one quote, so `'it''s'::mood` yields `it's`.
///
/// Returns `None` when the expression does not start with a quoted literal,
/// the literal is unterminated, or it is not immediately followed by `::`.
fn parse_pg_enum_default(default_expr: &str) -> Option<String> {
    // Pattern: 'value'::some_type
    let after_opening = default_expr.trim().strip_prefix('\'')?;

    let mut value = String::new();
    let mut chars = after_opening.char_indices().peekable();
    while let Some((idx, ch)) = chars.next() {
        if ch != '\'' {
            value.push(ch);
            continue;
        }
        // `''` is an escaped quote, not the end of the literal.
        if let Some((_, '\'')) = chars.peek() {
            chars.next();
            value.push('\'');
            continue;
        }
        // Closing quote: a cast must follow directly.
        let rest = &after_opening[idx + 1..];
        return if rest.starts_with("::") {
            Some(value)
        } else {
            None
        };
    }

    // Ran out of input without a closing quote.
    None
}
392
/// Returns the imports to emit for a table / view file.
///
/// In single-file mode every import containing `super::types::` is dropped,
/// because the types then live in the same file. Otherwise the set is returned
/// unchanged (as a copy).
fn filter_imports(imports: &BTreeSet<String>, single_file: bool) -> BTreeSet<String> {
    let mut kept = imports.clone();
    if single_file {
        kept.retain(|import| !import.contains("super::types::"));
    }
    kept
}
405
/// Detects `tab_spaces` from `rustfmt.toml` or `.rustfmt.toml` by walking up
/// the directory tree from `start_dir`.
///
/// If `start_dir` is a file, the search starts in its parent directory. In
/// each directory `rustfmt.toml` is tried before `.rustfmt.toml`; the first
/// config file that can be read ends the search. Returns the configured
/// `tab_spaces` value, or 4 (the rustfmt default) when the file has no usable
/// `tab_spaces` entry or no config file is found at all.
///
/// A trailing TOML comment after the value (`tab_spaces = 2 # note`) is
/// ignored rather than making the value unparseable.
pub fn detect_tab_spaces(start_dir: &Path) -> usize {
    const RUSTFMT_DEFAULT_TAB_SPACES: usize = 4;

    let mut dir = if start_dir.is_file() {
        start_dir.parent().unwrap_or(start_dir)
    } else {
        start_dir
    };
    loop {
        for name in &["rustfmt.toml", ".rustfmt.toml"] {
            if let Ok(content) = std::fs::read_to_string(dir.join(name)) {
                // The nearest config wins even if it lacks a `tab_spaces` key.
                return parse_tab_spaces_setting(&content).unwrap_or(RUSTFMT_DEFAULT_TAB_SPACES);
            }
        }
        match dir.parent() {
            Some(parent) => dir = parent,
            None => return RUSTFMT_DEFAULT_TAB_SPACES,
        }
    }
}

/// Returns the first parseable `tab_spaces` value in a rustfmt config text.
fn parse_tab_spaces_setting(content: &str) -> Option<usize> {
    content.lines().find_map(|line| {
        let rest = line.trim().strip_prefix("tab_spaces")?;
        let rest = rest.trim_start().strip_prefix('=').unwrap_or(rest);
        // Drop a trailing `# comment`; an integer value never contains `#`.
        let value = rest.split('#').next().unwrap_or(rest);
        value.trim().parse::<usize>().ok()
    })
}
437
438/// Parse and format a TokenStream via prettyplease, then post-process spacing.
439/// `tab_spaces` controls how many spaces per indentation level for SQL inside raw strings.
440pub(crate) fn parse_and_format(tokens: &TokenStream) -> crate::error::Result<String> {
441    parse_and_format_with_tab_spaces(tokens, 4)
442}
443
444pub(crate) fn parse_and_format_with_tab_spaces(
445    tokens: &TokenStream,
446    tab_spaces: usize,
447) -> crate::error::Result<String> {
448    let file = syn::parse2::<syn::File>(tokens.clone()).map_err(|e| {
449        crate::error::Error::Config(format!(
450            "Internal sqlx-gen bug: failed to parse generated code: {}. \
451             Raw tokens:\n  {}\n\
452             Please report this with the input schema.",
453            e, tokens
454        ))
455    })?;
456    let raw = prettyplease::unparse(&file);
457    let raw = indent_multiline_raw_strings(&raw, tab_spaces);
458    Ok(add_blank_lines_between_items(&raw))
459}
460
/// Format a single TokenStream block (no imports).
///
/// Delegates directly to [`parse_and_format`]; any error it returns is
/// passed through unchanged.
pub(crate) fn format_tokens(tokens: &TokenStream) -> crate::error::Result<String> {
    parse_and_format(tokens)
}
465
466pub fn format_tokens_with_imports(
467    tokens: &TokenStream,
468    imports: &BTreeSet<String>,
469) -> crate::error::Result<String> {
470    format_tokens_with_imports_and_tab_spaces(tokens, imports, 4)
471}
472
473pub fn format_tokens_with_imports_and_tab_spaces(
474    tokens: &TokenStream,
475    imports: &BTreeSet<String>,
476    tab_spaces: usize,
477) -> crate::error::Result<String> {
478    let formatted = parse_and_format_with_tab_spaces(tokens, tab_spaces)?;
479
480    let used_imports: Vec<&String> = imports
481        .iter()
482        .filter(|imp| is_import_used(imp, &formatted))
483        .collect();
484
485    if used_imports.is_empty() {
486        Ok(formatted)
487    } else {
488        let import_lines: String = used_imports.iter().map(|i| format!("{}\n", i)).collect();
489        Ok(format!("{}\n\n{}", import_lines.trim_end(), formatted))
490    }
491}
492
/// Checks whether an import line is actually used by the generated `code`.
///
/// The name each import brings into scope is extracted and searched for in
/// `code` as a whole identifier (so `Id` does not match inside `Identity`):
///
/// - `use foo::bar::Baz;` → looks for `Baz`
/// - `use foo::{A, B};` → used if `A` or `B` appears
/// - `use foo::Bar as Baz;` → looks for the alias `Baz`
/// - `use foo::bar::*;` and `use foo::Trait as _;` → always kept, because
///   their use cannot be detected by name
fn is_import_used(import: &str, code: &str) -> bool {
    let trimmed = import.trim().trim_end_matches(';');
    let path = trimmed.strip_prefix("use ").unwrap_or(trimmed);

    if path.ends_with("::*") {
        return true;
    }

    // Handle grouped imports: use foo::{A, B, C};
    if let (Some(start), Some(end)) = (path.find('{'), path.find('}')) {
        // Guard against a malformed `}` appearing before `{`.
        if start < end {
            return path[start + 1..end]
                .split(',')
                .map(str::trim)
                .filter(|item| !item.is_empty())
                .any(|item| is_import_item_used(item, code));
        }
    }

    // Simple import: use foo::Bar;
    let last_segment = path.rsplit("::").next().unwrap_or(path);
    is_import_item_used(last_segment, code)
}

/// Resolves the name bound by one import item (honouring `as` aliases and
/// nested paths) and reports whether it occurs in `code`.
fn is_import_item_used(item: &str, code: &str) -> bool {
    let binding = match item.rsplit_once(" as ") {
        Some((_, alias)) => alias.trim(),
        None => item.rsplit("::").next().unwrap_or(item).trim(),
    };
    // Nothing to search for, or an anonymous (`as _`) import: keep it.
    if binding.is_empty() || binding == "_" {
        return true;
    }
    contains_identifier(code, binding)
}

/// Reports whether `ident` occurs in `code` delimited by non-identifier
/// characters (or the text boundaries) on both sides.
fn contains_identifier(code: &str, ident: &str) -> bool {
    let is_ident_char = |c: char| c.is_alphanumeric() || c == '_';
    code.match_indices(ident).any(|(start, matched)| {
        let before = code[..start].chars().next_back();
        let after = code[start + matched.len()..].chars().next();
        before.map_or(true, |c| !is_ident_char(c)) && after.map_or(true, |c| !is_ident_char(c))
    })
}
525
/// Re-indents the content of multi-line raw string literals (`r#"..."#`) so
/// SQL reads naturally in generated code.
///
/// A raw string counts as multi-line when a line contains `r#"` with no `"#`
/// after it; it ends at the first following line that starts (after leading
/// whitespace) with `"#`. Content lines are placed at a fixed indentation of
/// `4 + 2 * tab_spaces` columns while their indentation relative to each other
/// is preserved (e.g. SET items indented under the SET keyword); whitespace-
/// only content lines become empty lines. The closing `"#` line is placed at
/// `4 + tab_spaces` columns. All other lines pass through unchanged.
///
/// If a raw string is never closed, its buffered lines are emitted unchanged
/// instead of being dropped. The result has no trailing newline.
fn indent_multiline_raw_strings(code: &str, tab_spaces: usize) -> String {
    // Raw string content is NOT reformatted by external formatters, so the
    // right indentation must be baked in at generation time.
    let close_indent = 4 + tab_spaces; // impl(4) + fn_arg(tab)
    let sql_indent = 4 + 2 * tab_spaces; // impl(4) + fn_arg(tab) + sql(tab)

    // Width (in bytes) of a line's leading whitespace.
    let leading_width = |s: &str| s.len() - s.trim_start().len();

    let mut result: Vec<String> = Vec::new();
    let mut inside_raw = false;
    let mut raw_lines: Vec<&str> = Vec::new();

    for line in code.lines() {
        if !inside_raw {
            if let Some(pos) = line.find("r#\"") {
                // Opened but not closed on the same line → multi-line literal.
                if !line[pos + 3..].contains("\"#") {
                    inside_raw = true;
                    raw_lines.clear();
                }
            }
            result.push(line.to_string());
        } else if line.trim_start().starts_with("\"#") {
            // Smallest indentation among non-blank content lines is the
            // baseline that relative indentation is measured from.
            let min_indent = raw_lines
                .iter()
                .filter(|l| !l.trim().is_empty())
                .map(|l| leading_width(l))
                .min()
                .unwrap_or(0);
            for raw_line in &raw_lines {
                let trimmed = raw_line.trim();
                if trimmed.is_empty() {
                    result.push(String::new());
                } else {
                    let relative = leading_width(raw_line).saturating_sub(min_indent);
                    result.push(format!(
                        "{}{}{}",
                        " ".repeat(sql_indent),
                        " ".repeat(relative),
                        trimmed
                    ));
                }
            }
            // Closing "# aligned with the opening argument level.
            result.push(format!("{}{}", " ".repeat(close_indent), line.trim()));
            raw_lines.clear();
            inside_raw = false;
        } else {
            raw_lines.push(line);
        }
    }

    // Unterminated raw string: keep what was buffered rather than losing it.
    if inside_raw {
        result.extend(raw_lines.iter().map(|l| l.to_string()));
    }

    result.join("\n")
}
588
/// Inserts blank lines into formatted code to visually separate items.
///
/// Each line is compared (trimmed) with the trimmed previous line, and one
/// blank line is inserted before it for every rule that matches:
///
/// 1. a `#[sqlx(rename` attribute following a line that ends with `,`
///    (i.e. every renamed enum variant except the first);
/// 2. a `pub struct`, `impl `, `#[derive`, `pub async fn` or `pub fn` line
///    following a lone `}`;
/// 3. a `let ` or `Ok(` line following a line that ends with `.await?;` or
///    `.await?`, or that ends with `;` and contains `.unwrap_or(`;
/// 4. a `let ` line mentioning `sqlx::` following a `let ` line that does not.
///
/// Rules are independent, so a line matching several of them gets several
/// blank lines. The result has no trailing newline.
fn add_blank_lines_between_items(code: &str) -> String {
    const ITEM_PREFIXES: [&str; 5] = ["pub struct", "impl ", "#[derive", "pub async fn", "pub fn"];

    let mut out: Vec<&str> = Vec::new();
    let mut previous: Option<&str> = None; // trimmed text of the preceding line

    for raw in code.lines() {
        let current = raw.trim();

        if let Some(prev) = previous {
            // Rule 1: separate renamed enum variants.
            if current.starts_with("#[sqlx(rename") && prev.ends_with(',') {
                out.push("");
            }

            // Rule 2: separate items / methods that follow a closing brace.
            let starts_item = ITEM_PREFIXES
                .iter()
                .any(|prefix| current.starts_with(*prefix));
            if prev == "}" && starts_item {
                out.push("");
            }

            // Rule 3: new logical block after an awaited call or `unwrap_or`.
            let prev_ends_call = prev.ends_with(".await?;")
                || prev.ends_with(".await?")
                || (prev.ends_with(';') && prev.contains(".unwrap_or("));
            let starts_block = current.starts_with("let ") || current.starts_with("Ok(");
            if prev_ends_call && starts_block {
                out.push("");
            }

            // Rule 4: set a sqlx query `let` apart from plain `let` bindings.
            let current_is_sqlx_let = current.starts_with("let ") && current.contains("sqlx::");
            let prev_is_plain_let = prev.starts_with("let ") && !prev.contains("sqlx::");
            if current_is_sqlx_let && prev_is_plain_let {
                out.push("");
            }
        }

        out.push(raw);
        previous = Some(current);
    }

    out.join("\n")
}
646
647#[cfg(test)]
648mod tests {
649    use super::*;
650    use crate::introspect::{
651        ColumnInfo, CompositeTypeInfo, DomainInfo, EnumInfo, SchemaInfo, TableInfo,
652    };
653    use std::collections::HashMap;
654
655    // ========== is_rust_keyword ==========
656
657    #[test]
658    fn test_keyword_type() {
659        assert!(is_rust_keyword("type"));
660    }
661
662    #[test]
663    fn test_keyword_fn() {
664        assert!(is_rust_keyword("fn"));
665    }
666
667    #[test]
668    fn test_keyword_let() {
669        assert!(is_rust_keyword("let"));
670    }
671
672    #[test]
673    fn test_keyword_match() {
674        assert!(is_rust_keyword("match"));
675    }
676
677    #[test]
678    fn test_keyword_async() {
679        assert!(is_rust_keyword("async"));
680    }
681
682    #[test]
683    fn test_keyword_await() {
684        assert!(is_rust_keyword("await"));
685    }
686
687    #[test]
688    fn test_keyword_yield() {
689        assert!(is_rust_keyword("yield"));
690    }
691
692    #[test]
693    fn test_keyword_abstract() {
694        assert!(is_rust_keyword("abstract"));
695    }
696
697    #[test]
698    fn test_keyword_try() {
699        assert!(is_rust_keyword("try"));
700    }
701
702    #[test]
703    fn test_not_keyword_name() {
704        assert!(!is_rust_keyword("name"));
705    }
706
707    #[test]
708    fn test_not_keyword_id() {
709        assert!(!is_rust_keyword("id"));
710    }
711
712    #[test]
713    fn test_not_keyword_uppercase_type() {
714        assert!(!is_rust_keyword("Type"));
715    }
716
717    // ========== normalize_module_name ==========
718
719    #[test]
720    fn test_normalize_no_underscores() {
721        assert_eq!(normalize_module_name("users"), "users");
722    }
723
724    #[test]
725    fn test_normalize_single_underscore() {
726        assert_eq!(normalize_module_name("user_roles"), "user_roles");
727    }
728
729    #[test]
730    fn test_normalize_double_underscore() {
731        assert_eq!(normalize_module_name("user__roles"), "user_roles");
732    }
733
734    #[test]
735    fn test_normalize_triple_underscore() {
736        assert_eq!(normalize_module_name("a___b"), "a_b");
737    }
738
739    #[test]
740    fn test_normalize_leading_underscore() {
741        assert_eq!(normalize_module_name("_private"), "_private");
742    }
743
744    #[test]
745    fn test_normalize_trailing_underscore() {
746        assert_eq!(normalize_module_name("name_"), "name_");
747    }
748
749    #[test]
750    fn test_normalize_double_leading() {
751        assert_eq!(normalize_module_name("__double_leading"), "_double_leading");
752    }
753
754    #[test]
755    fn test_normalize_multiple_groups() {
756        assert_eq!(normalize_module_name("a__b__c"), "a_b_c");
757    }
758
759    // ========== build_module_name ==========
760
761    #[test]
762    fn test_build_no_collision_no_prefix() {
763        assert_eq!(build_module_name("public", "users", false), "users");
764    }
765
766    #[test]
767    fn test_build_no_collision_non_default_no_prefix() {
768        assert_eq!(build_module_name("billing", "invoices", false), "invoices");
769    }
770
771    #[test]
772    fn test_build_collision_prefixed() {
773        assert_eq!(build_module_name("billing", "users", true), "billing_users");
774    }
775
776    #[test]
777    fn test_build_collision_default_schema_no_prefix() {
778        assert_eq!(build_module_name("public", "users", true), "users");
779    }
780
781    #[test]
782    fn test_build_collision_normalizes_double_underscore() {
783        assert_eq!(
784            build_module_name("billing", "agent__connector", true),
785            "billing_agent_connector"
786        );
787    }
788
789    // ========== is_default_schema ==========
790
791    #[test]
792    fn test_default_schema_public() {
793        assert!(is_default_schema("public"));
794    }
795
796    #[test]
797    fn test_default_schema_main() {
798        assert!(is_default_schema("main"));
799    }
800
801    #[test]
802    fn test_non_default_schema() {
803        assert!(!is_default_schema("billing"));
804    }
805
806    // ========== imports_for_derives ==========
807
808    #[test]
809    fn test_imports_empty() {
810        let result = imports_for_derives(&[]);
811        assert!(result.is_empty());
812    }
813
814    #[test]
815    fn test_imports_serialize_only() {
816        let derives = vec!["Serialize".to_string()];
817        let result = imports_for_derives(&derives);
818        assert_eq!(result, vec!["use serde::{Serialize};"]);
819    }
820
821    #[test]
822    fn test_imports_deserialize_only() {
823        let derives = vec!["Deserialize".to_string()];
824        let result = imports_for_derives(&derives);
825        assert_eq!(result, vec!["use serde::{Deserialize};"]);
826    }
827
828    #[test]
829    fn test_imports_both_serde() {
830        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
831        let result = imports_for_derives(&derives);
832        assert_eq!(result, vec!["use serde::{Serialize, Deserialize};"]);
833    }
834
835    #[test]
836    fn test_imports_non_serde() {
837        let derives = vec!["Hash".to_string()];
838        let result = imports_for_derives(&derives);
839        assert!(result.is_empty());
840    }
841
842    #[test]
843    fn test_imports_non_serde_multiple() {
844        let derives = vec!["PartialEq".to_string(), "Eq".to_string()];
845        let result = imports_for_derives(&derives);
846        assert!(result.is_empty());
847    }
848
849    #[test]
850    fn test_imports_mixed_serde_and_others() {
851        let derives = vec![
852            "Serialize".to_string(),
853            "Hash".to_string(),
854            "Deserialize".to_string(),
855        ];
856        let result = imports_for_derives(&derives);
857        assert_eq!(result.len(), 1);
858        assert!(result[0].contains("Serialize"));
859        assert!(result[0].contains("Deserialize"));
860    }
861
862    // ========== add_blank_lines_between_items ==========
863
864    #[test]
865    fn test_blank_lines_between_renamed_variants() {
866        let input = "pub enum Foo {\n    #[sqlx(rename = \"a\")]\n    A,\n    #[sqlx(rename = \"b\")]\n    B,\n}";
867        let result = add_blank_lines_between_items(input);
868        assert!(result.contains("A,\n\n    #[sqlx(rename = \"b\")]"));
869    }
870
871    #[test]
872    fn test_no_blank_line_for_first_variant() {
873        let input = "pub enum Foo {\n    #[sqlx(rename = \"a\")]\n    A,\n}";
874        let result = add_blank_lines_between_items(input);
875        // No blank line before first #[sqlx(rename because previous line is `{`
876        assert!(!result.contains("{\n\n"));
877    }
878
879    #[test]
880    fn test_no_change_without_rename() {
881        let input = "pub enum Foo {\n    A,\n    B,\n}";
882        let result = add_blank_lines_between_items(input);
883        assert_eq!(result, input);
884    }
885
886    #[test]
887    fn test_no_change_for_struct() {
888        let input = "pub struct Foo {\n    pub a: i32,\n    pub b: String,\n}";
889        let result = add_blank_lines_between_items(input);
890        assert_eq!(result, input);
891    }
892
893    // ========== rust_type_name_for / cross-schema collisions ==========
894
895    fn schema_with_two_role_enums() -> SchemaInfo {
896        SchemaInfo {
897            enums: vec![
898                crate::introspect::EnumInfo {
899                    schema_name: "auth".into(),
900                    name: "role".into(),
901                    variants: vec!["admin".into(), "user".into()],
902                    default_variant: None,
903                },
904                crate::introspect::EnumInfo {
905                    schema_name: "billing".into(),
906                    name: "role".into(),
907                    variants: vec!["payer".into(), "payee".into()],
908                    default_variant: None,
909                },
910            ],
911            ..Default::default()
912        }
913    }
914
915    #[test]
916    fn rust_type_name_prefixes_schema_on_cross_schema_collision() {
917        let s = schema_with_two_role_enums();
918        assert_eq!(rust_type_name_for(&s, "auth", "role"), "AuthRole");
919        assert_eq!(rust_type_name_for(&s, "billing", "role"), "BillingRole");
920    }
921
922    #[test]
923    fn rust_type_name_keeps_bare_name_when_unique() {
924        let s = SchemaInfo {
925            enums: vec![crate::introspect::EnumInfo {
926                schema_name: "auth".into(),
927                name: "role".into(),
928                variants: vec!["admin".into()],
929                default_variant: None,
930            }],
931            ..Default::default()
932        };
933        assert_eq!(rust_type_name_for(&s, "auth", "role"), "Role");
934    }
935
936    #[test]
937    fn required_search_path_collects_non_default_schemas() {
938        let s = SchemaInfo {
939            enums: vec![
940                crate::introspect::EnumInfo {
941                    schema_name: "auth".into(),
942                    name: "role".into(),
943                    variants: vec!["x".into()],
944                    default_variant: None,
945                },
946                crate::introspect::EnumInfo {
947                    schema_name: "public".into(),
948                    name: "status".into(),
949                    variants: vec!["y".into()],
950                    default_variant: None,
951                },
952            ],
953            composite_types: vec![crate::introspect::CompositeTypeInfo {
954                schema_name: "billing".into(),
955                name: "addr".into(),
956                fields: vec![],
957            }],
958            domains: vec![crate::introspect::DomainInfo {
959                schema_name: "auth".into(),
960                name: "email".into(),
961                base_type: "text".into(),
962            }],
963            ..Default::default()
964        };
965        // Sorted, deduplicated, public excluded.
966        assert_eq!(required_pg_search_path(&s), vec!["auth", "billing"]);
967    }
968
969    #[test]
970    fn required_search_path_empty_when_only_default_schema() {
971        let s = SchemaInfo {
972            enums: vec![crate::introspect::EnumInfo {
973                schema_name: "public".into(),
974                name: "status".into(),
975                variants: vec!["y".into()],
976                default_variant: None,
977            }],
978            ..Default::default()
979        };
980        assert!(required_pg_search_path(&s).is_empty());
981    }
982
983    #[test]
984    fn rust_type_name_default_schema_keeps_bare_name_even_on_collision() {
985        let s = SchemaInfo {
986            enums: vec![
987                crate::introspect::EnumInfo {
988                    schema_name: "public".into(),
989                    name: "role".into(),
990                    variants: vec!["a".into()],
991                    default_variant: None,
992                },
993                crate::introspect::EnumInfo {
994                    schema_name: "auth".into(),
995                    name: "role".into(),
996                    variants: vec!["b".into()],
997                    default_variant: None,
998                },
999            ],
1000            ..Default::default()
1001        };
1002        // public stays "Role"; auth gets the schema prefix to break the tie.
1003        assert_eq!(rust_type_name_for(&s, "public", "role"), "Role");
1004        assert_eq!(rust_type_name_for(&s, "auth", "role"), "AuthRole");
1005    }
1006
1007    // ========== filter_imports ==========
1008
1009    #[test]
1010    fn test_filter_single_file_strips_super_types() {
1011        let mut imports = BTreeSet::new();
1012        imports.insert("use super::types::Foo;".to_string());
1013        imports.insert("use chrono::NaiveDateTime;".to_string());
1014        let result = filter_imports(&imports, true);
1015        assert!(!result.contains("use super::types::Foo;"));
1016        assert!(result.contains("use chrono::NaiveDateTime;"));
1017    }
1018
1019    #[test]
1020    fn test_filter_single_file_keeps_other_imports() {
1021        let mut imports = BTreeSet::new();
1022        imports.insert("use chrono::NaiveDateTime;".to_string());
1023        let result = filter_imports(&imports, true);
1024        assert!(result.contains("use chrono::NaiveDateTime;"));
1025    }
1026
1027    #[test]
1028    fn test_filter_multi_file_keeps_all() {
1029        let mut imports = BTreeSet::new();
1030        imports.insert("use super::types::Foo;".to_string());
1031        imports.insert("use chrono::NaiveDateTime;".to_string());
1032        let result = filter_imports(&imports, false);
1033        assert_eq!(result.len(), 2);
1034    }
1035
1036    #[test]
1037    fn test_filter_empty_set() {
1038        let imports = BTreeSet::new();
1039        let result = filter_imports(&imports, true);
1040        assert!(result.is_empty());
1041    }
1042
1043    // ========== generate() orchestrator ==========
1044
1045    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
1046        TableInfo {
1047            schema_name: "public".to_string(),
1048            name: name.to_string(),
1049            columns,
1050        }
1051    }
1052
1053    fn make_col(name: &str, udt_name: &str) -> ColumnInfo {
1054        ColumnInfo {
1055            name: name.to_string(),
1056            data_type: udt_name.to_string(),
1057            udt_name: udt_name.to_string(),
1058            is_nullable: false,
1059            is_primary_key: false,
1060            ordinal_position: 0,
1061            schema_name: "public".to_string(),
1062            udt_schema: None,
1063            column_default: None,
1064        }
1065    }
1066
1067    #[test]
1068    fn test_generate_empty_schema() {
1069        let schema = SchemaInfo::default();
1070        let files = generate(
1071            &schema,
1072            DatabaseKind::Postgres,
1073            &[],
1074            &HashMap::new(),
1075            false,
1076            TimeCrate::Chrono,
1077        )
1078        .unwrap();
1079        assert!(files.is_empty());
1080    }
1081
1082    #[test]
1083    fn test_generate_one_table() {
1084        let schema = SchemaInfo {
1085            tables: vec![make_table("users", vec![make_col("id", "int4")])],
1086            ..Default::default()
1087        };
1088        let files = generate(
1089            &schema,
1090            DatabaseKind::Postgres,
1091            &[],
1092            &HashMap::new(),
1093            false,
1094            TimeCrate::Chrono,
1095        )
1096        .unwrap();
1097        assert_eq!(files.len(), 1);
1098        assert_eq!(files[0].filename, "users.rs");
1099    }
1100
1101    #[test]
1102    fn test_generate_two_tables() {
1103        let schema = SchemaInfo {
1104            tables: vec![
1105                make_table("users", vec![make_col("id", "int4")]),
1106                make_table("posts", vec![make_col("id", "int4")]),
1107            ],
1108            ..Default::default()
1109        };
1110        let files = generate(
1111            &schema,
1112            DatabaseKind::Postgres,
1113            &[],
1114            &HashMap::new(),
1115            false,
1116            TimeCrate::Chrono,
1117        )
1118        .unwrap();
1119        assert_eq!(files.len(), 2);
1120    }
1121
1122    #[test]
1123    fn test_generate_enum_creates_types_file() {
1124        let schema = SchemaInfo {
1125            enums: vec![EnumInfo {
1126                schema_name: "public".to_string(),
1127                name: "status".to_string(),
1128                variants: vec!["active".to_string(), "inactive".to_string()],
1129                default_variant: None,
1130            }],
1131            ..Default::default()
1132        };
1133        let files = generate(
1134            &schema,
1135            DatabaseKind::Postgres,
1136            &[],
1137            &HashMap::new(),
1138            false,
1139            TimeCrate::Chrono,
1140        )
1141        .unwrap();
1142        assert_eq!(files.len(), 1);
1143        assert_eq!(files[0].filename, "types.rs");
1144    }
1145
1146    #[test]
1147    fn test_generate_enums_composites_domains_single_types_file() {
1148        let schema = SchemaInfo {
1149            enums: vec![EnumInfo {
1150                schema_name: "public".to_string(),
1151                name: "status".to_string(),
1152                variants: vec!["active".to_string()],
1153                default_variant: None,
1154            }],
1155            composite_types: vec![CompositeTypeInfo {
1156                schema_name: "public".to_string(),
1157                name: "address".to_string(),
1158                fields: vec![make_col("street", "text")],
1159            }],
1160            domains: vec![DomainInfo {
1161                schema_name: "public".to_string(),
1162                name: "email".to_string(),
1163                base_type: "text".to_string(),
1164            }],
1165            ..Default::default()
1166        };
1167        let files = generate(
1168            &schema,
1169            DatabaseKind::Postgres,
1170            &[],
1171            &HashMap::new(),
1172            false,
1173            TimeCrate::Chrono,
1174        )
1175        .unwrap();
1176        // Should produce exactly 1 types.rs
1177        let types_files: Vec<_> = files.iter().filter(|f| f.filename == "types.rs").collect();
1178        assert_eq!(types_files.len(), 1);
1179    }
1180
1181    #[test]
1182    fn test_generate_tables_and_enums() {
1183        let schema = SchemaInfo {
1184            tables: vec![make_table("users", vec![make_col("id", "int4")])],
1185            enums: vec![EnumInfo {
1186                schema_name: "public".to_string(),
1187                name: "status".to_string(),
1188                variants: vec!["active".to_string()],
1189                default_variant: None,
1190            }],
1191            ..Default::default()
1192        };
1193        let files = generate(
1194            &schema,
1195            DatabaseKind::Postgres,
1196            &[],
1197            &HashMap::new(),
1198            false,
1199            TimeCrate::Chrono,
1200        )
1201        .unwrap();
1202        assert_eq!(files.len(), 2); // users.rs + types.rs
1203    }
1204
1205    #[test]
1206    fn test_generate_filename_normalized() {
1207        let schema = SchemaInfo {
1208            tables: vec![make_table("user__data", vec![make_col("id", "int4")])],
1209            ..Default::default()
1210        };
1211        let files = generate(
1212            &schema,
1213            DatabaseKind::Postgres,
1214            &[],
1215            &HashMap::new(),
1216            false,
1217            TimeCrate::Chrono,
1218        )
1219        .unwrap();
1220        assert_eq!(files[0].filename, "user_data.rs");
1221    }
1222
1223    #[test]
1224    fn test_generate_no_origin_for_tables() {
1225        let schema = SchemaInfo {
1226            tables: vec![make_table("users", vec![make_col("id", "int4")])],
1227            ..Default::default()
1228        };
1229        let files = generate(
1230            &schema,
1231            DatabaseKind::Postgres,
1232            &[],
1233            &HashMap::new(),
1234            false,
1235            TimeCrate::Chrono,
1236        )
1237        .unwrap();
1238        assert_eq!(files[0].origin, None);
1239    }
1240
1241    #[test]
1242    fn test_generate_types_no_origin() {
1243        let schema = SchemaInfo {
1244            enums: vec![EnumInfo {
1245                schema_name: "public".to_string(),
1246                name: "status".to_string(),
1247                variants: vec!["active".to_string()],
1248                default_variant: None,
1249            }],
1250            ..Default::default()
1251        };
1252        let files = generate(
1253            &schema,
1254            DatabaseKind::Postgres,
1255            &[],
1256            &HashMap::new(),
1257            false,
1258            TimeCrate::Chrono,
1259        )
1260        .unwrap();
1261        assert_eq!(files[0].origin, None);
1262    }
1263
1264    #[test]
1265    fn test_generate_single_file_filters_super_types_imports() {
1266        let schema = SchemaInfo {
1267            tables: vec![make_table("users", vec![make_col("id", "int4")])],
1268            enums: vec![EnumInfo {
1269                schema_name: "public".to_string(),
1270                name: "status".to_string(),
1271                variants: vec!["active".to_string()],
1272                default_variant: None,
1273            }],
1274            ..Default::default()
1275        };
1276        let files = generate(
1277            &schema,
1278            DatabaseKind::Postgres,
1279            &[],
1280            &HashMap::new(),
1281            true,
1282            TimeCrate::Chrono,
1283        )
1284        .unwrap();
1285        // struct file should not have super::types:: imports
1286        let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
1287        assert!(!struct_file.code.contains("super::types::"));
1288    }
1289
1290    #[test]
1291    fn test_generate_multi_file_keeps_super_types_imports() {
1292        // Table with a column referencing an enum
1293        let schema = SchemaInfo {
1294            tables: vec![make_table("users", vec![make_col("status", "status")])],
1295            enums: vec![EnumInfo {
1296                schema_name: "public".to_string(),
1297                name: "status".to_string(),
1298                variants: vec!["active".to_string()],
1299                default_variant: None,
1300            }],
1301            ..Default::default()
1302        };
1303        let files = generate(
1304            &schema,
1305            DatabaseKind::Postgres,
1306            &[],
1307            &HashMap::new(),
1308            false,
1309            TimeCrate::Chrono,
1310        )
1311        .unwrap();
1312        let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
1313        assert!(struct_file.code.contains("super::types::"));
1314    }
1315
1316    #[test]
1317    fn test_generate_extra_derives_in_struct() {
1318        let schema = SchemaInfo {
1319            tables: vec![make_table("users", vec![make_col("id", "int4")])],
1320            ..Default::default()
1321        };
1322        let derives = vec!["Serialize".to_string()];
1323        let files = generate(
1324            &schema,
1325            DatabaseKind::Postgres,
1326            &derives,
1327            &HashMap::new(),
1328            false,
1329            TimeCrate::Chrono,
1330        )
1331        .unwrap();
1332        assert!(files[0].code.contains("Serialize"));
1333    }
1334
1335    #[test]
1336    fn test_generate_extra_derives_in_enum() {
1337        let schema = SchemaInfo {
1338            enums: vec![EnumInfo {
1339                schema_name: "public".to_string(),
1340                name: "status".to_string(),
1341                variants: vec!["active".to_string()],
1342                default_variant: None,
1343            }],
1344            ..Default::default()
1345        };
1346        let derives = vec!["Serialize".to_string()];
1347        let files = generate(
1348            &schema,
1349            DatabaseKind::Postgres,
1350            &derives,
1351            &HashMap::new(),
1352            false,
1353            TimeCrate::Chrono,
1354        )
1355        .unwrap();
1356        assert!(files[0].code.contains("Serialize"));
1357    }
1358
1359    #[test]
1360    fn test_generate_type_overrides_in_struct() {
1361        let mut overrides = HashMap::new();
1362        overrides.insert("jsonb".to_string(), "MyJson".to_string());
1363        let schema = SchemaInfo {
1364            tables: vec![make_table("users", vec![make_col("data", "jsonb")])],
1365            ..Default::default()
1366        };
1367        let files = generate(
1368            &schema,
1369            DatabaseKind::Postgres,
1370            &[],
1371            &overrides,
1372            false,
1373            TimeCrate::Chrono,
1374        )
1375        .unwrap();
1376        assert!(files[0].code.contains("MyJson"));
1377    }
1378
1379    #[test]
1380    fn test_generate_valid_rust_syntax() {
1381        let schema = SchemaInfo {
1382            tables: vec![make_table(
1383                "users",
1384                vec![make_col("id", "int4"), make_col("name", "text")],
1385            )],
1386            enums: vec![EnumInfo {
1387                schema_name: "public".to_string(),
1388                name: "status".to_string(),
1389                variants: vec!["active".to_string(), "inactive".to_string()],
1390                default_variant: None,
1391            }],
1392            ..Default::default()
1393        };
1394        let files = generate(
1395            &schema,
1396            DatabaseKind::Postgres,
1397            &[],
1398            &HashMap::new(),
1399            false,
1400            TimeCrate::Chrono,
1401        )
1402        .unwrap();
1403        for f in &files {
1404            // Should be parseable as valid Rust
1405            let parse_result = syn::parse_file(&f.code);
1406            assert!(
1407                parse_result.is_ok(),
1408                "Failed to parse {}: {:?}",
1409                f.filename,
1410                parse_result.err()
1411            );
1412        }
1413    }
1414
1415    // ========== generate() — views ==========
1416
1417    fn make_view(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
1418        TableInfo {
1419            schema_name: "public".to_string(),
1420            name: name.to_string(),
1421            columns,
1422        }
1423    }
1424
1425    #[test]
1426    fn test_generate_one_view() {
1427        let schema = SchemaInfo {
1428            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
1429            ..Default::default()
1430        };
1431        let files = generate(
1432            &schema,
1433            DatabaseKind::Postgres,
1434            &[],
1435            &HashMap::new(),
1436            false,
1437            TimeCrate::Chrono,
1438        )
1439        .unwrap();
1440        assert_eq!(files.len(), 1);
1441        assert_eq!(files[0].filename, "active_users.rs");
1442    }
1443
1444    #[test]
1445    fn test_generate_no_origin_for_views() {
1446        let schema = SchemaInfo {
1447            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
1448            ..Default::default()
1449        };
1450        let files = generate(
1451            &schema,
1452            DatabaseKind::Postgres,
1453            &[],
1454            &HashMap::new(),
1455            false,
1456            TimeCrate::Chrono,
1457        )
1458        .unwrap();
1459        assert_eq!(files[0].origin, None);
1460    }
1461
1462    #[test]
1463    fn test_generate_tables_and_views() {
1464        let schema = SchemaInfo {
1465            tables: vec![make_table("users", vec![make_col("id", "int4")])],
1466            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
1467            ..Default::default()
1468        };
1469        let files = generate(
1470            &schema,
1471            DatabaseKind::Postgres,
1472            &[],
1473            &HashMap::new(),
1474            false,
1475            TimeCrate::Chrono,
1476        )
1477        .unwrap();
1478        assert_eq!(files.len(), 2);
1479    }
1480
1481    #[test]
1482    fn test_generate_view_valid_rust() {
1483        let schema = SchemaInfo {
1484            views: vec![make_view(
1485                "active_users",
1486                vec![make_col("id", "int4"), make_col("name", "text")],
1487            )],
1488            ..Default::default()
1489        };
1490        let files = generate(
1491            &schema,
1492            DatabaseKind::Postgres,
1493            &[],
1494            &HashMap::new(),
1495            false,
1496            TimeCrate::Chrono,
1497        )
1498        .unwrap();
1499        let parse_result = syn::parse_file(&files[0].code);
1500        assert!(
1501            parse_result.is_ok(),
1502            "Failed to parse: {:?}",
1503            parse_result.err()
1504        );
1505    }
1506
1507    #[test]
1508    fn test_generate_view_nullable_column() {
1509        let schema = SchemaInfo {
1510            views: vec![make_view(
1511                "v",
1512                vec![ColumnInfo {
1513                    name: "email".to_string(),
1514                    data_type: "text".to_string(),
1515                    udt_name: "text".to_string(),
1516                    is_nullable: true,
1517                    is_primary_key: false,
1518                    ordinal_position: 0,
1519                    schema_name: "public".to_string(),
1520                    udt_schema: None,
1521                    column_default: None,
1522                }],
1523            )],
1524            ..Default::default()
1525        };
1526        let files = generate(
1527            &schema,
1528            DatabaseKind::Postgres,
1529            &[],
1530            &HashMap::new(),
1531            false,
1532            TimeCrate::Chrono,
1533        )
1534        .unwrap();
1535        assert!(files[0].code.contains("Option<String>"));
1536    }
1537
1538    #[test]
1539    fn test_generate_collision_both_prefixed() {
1540        let schema = SchemaInfo {
1541            tables: vec![
1542                make_table("users", vec![make_col("id", "int4")]),
1543                TableInfo {
1544                    schema_name: "billing".to_string(),
1545                    name: "users".to_string(),
1546                    columns: vec![make_col("id", "int4")],
1547                },
1548            ],
1549            ..Default::default()
1550        };
1551        let files = generate(
1552            &schema,
1553            DatabaseKind::Postgres,
1554            &[],
1555            &HashMap::new(),
1556            false,
1557            TimeCrate::Chrono,
1558        )
1559        .unwrap();
1560        let filenames: Vec<_> = files.iter().map(|f| f.filename.as_str()).collect();
1561        assert!(filenames.contains(&"users.rs"));
1562        assert!(filenames.contains(&"billing_users.rs"));
1563    }
1564
1565    #[test]
1566    fn test_generate_no_collision_no_prefix() {
1567        let schema = SchemaInfo {
1568            tables: vec![
1569                make_table("users", vec![make_col("id", "int4")]),
1570                TableInfo {
1571                    schema_name: "billing".to_string(),
1572                    name: "invoices".to_string(),
1573                    columns: vec![make_col("id", "int4")],
1574                },
1575            ],
1576            ..Default::default()
1577        };
1578        let files = generate(
1579            &schema,
1580            DatabaseKind::Postgres,
1581            &[],
1582            &HashMap::new(),
1583            false,
1584            TimeCrate::Chrono,
1585        )
1586        .unwrap();
1587        let filenames: Vec<_> = files.iter().map(|f| f.filename.as_str()).collect();
1588        assert!(filenames.contains(&"users.rs"));
1589        assert!(filenames.contains(&"invoices.rs"));
1590    }
1591
1592    #[test]
1593    fn test_generate_single_schema_no_prefix() {
1594        let schema = SchemaInfo {
1595            tables: vec![
1596                make_table("users", vec![make_col("id", "int4")]),
1597                make_table("posts", vec![make_col("id", "int4")]),
1598            ],
1599            ..Default::default()
1600        };
1601        let files = generate(
1602            &schema,
1603            DatabaseKind::Postgres,
1604            &[],
1605            &HashMap::new(),
1606            false,
1607            TimeCrate::Chrono,
1608        )
1609        .unwrap();
1610        assert_eq!(files[0].filename, "users.rs");
1611        assert_eq!(files[1].filename, "posts.rs");
1612    }
1613
1614    #[test]
1615    fn test_generate_view_single_file_mode() {
1616        let schema = SchemaInfo {
1617            tables: vec![make_table("users", vec![make_col("id", "int4")])],
1618            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
1619            ..Default::default()
1620        };
1621        let files = generate(
1622            &schema,
1623            DatabaseKind::Postgres,
1624            &[],
1625            &HashMap::new(),
1626            true,
1627            TimeCrate::Chrono,
1628        )
1629        .unwrap();
1630        assert_eq!(files.len(), 2);
1631    }
1632
1633    // ========== parse_pg_enum_default ==========
1634
1635    #[test]
1636    fn test_parse_pg_enum_default_simple() {
1637        assert_eq!(
1638            parse_pg_enum_default("'idle'::task_status"),
1639            Some("idle".to_string())
1640        );
1641    }
1642
1643    #[test]
1644    fn test_parse_pg_enum_default_schema_qualified() {
1645        assert_eq!(
1646            parse_pg_enum_default("'active'::public.task_status"),
1647            Some("active".to_string())
1648        );
1649    }
1650
1651    #[test]
1652    fn test_parse_pg_enum_default_not_enum() {
1653        // No single-quote pattern
1654        assert_eq!(parse_pg_enum_default("nextval('users_id_seq')"), None);
1655    }
1656
1657    #[test]
1658    fn test_parse_pg_enum_default_no_cast() {
1659        assert_eq!(parse_pg_enum_default("'hello'"), None);
1660    }
1661
1662    #[test]
1663    fn test_parse_pg_enum_default_empty() {
1664        assert_eq!(parse_pg_enum_default(""), None);
1665    }
1666
1667    // ========== extract_enum_defaults ==========
1668
1669    #[test]
1670    fn test_extract_enum_defaults_from_column() {
1671        let schema = SchemaInfo {
1672            tables: vec![TableInfo {
1673                schema_name: "public".to_string(),
1674                name: "tasks".to_string(),
1675                columns: vec![ColumnInfo {
1676                    name: "status".to_string(),
1677                    data_type: "USER-DEFINED".to_string(),
1678                    udt_name: "task_status".to_string(),
1679                    is_nullable: false,
1680                    is_primary_key: false,
1681                    ordinal_position: 0,
1682                    schema_name: "public".to_string(),
1683                    udt_schema: None,
1684                    column_default: Some("'idle'::task_status".to_string()),
1685                }],
1686            }],
1687            enums: vec![EnumInfo {
1688                schema_name: "public".to_string(),
1689                name: "task_status".to_string(),
1690                variants: vec!["idle".to_string(), "running".to_string()],
1691                default_variant: None,
1692            }],
1693            ..Default::default()
1694        };
1695        let defaults = extract_enum_defaults(&schema);
1696        assert_eq!(defaults.get("task_status"), Some(&"idle".to_string()));
1697    }
1698
1699    #[test]
1700    fn test_extract_enum_defaults_no_default() {
1701        let schema = SchemaInfo {
1702            tables: vec![TableInfo {
1703                schema_name: "public".to_string(),
1704                name: "tasks".to_string(),
1705                columns: vec![ColumnInfo {
1706                    name: "status".to_string(),
1707                    data_type: "USER-DEFINED".to_string(),
1708                    udt_name: "task_status".to_string(),
1709                    is_nullable: false,
1710                    is_primary_key: false,
1711                    ordinal_position: 0,
1712                    schema_name: "public".to_string(),
1713                    udt_schema: None,
1714                    column_default: None,
1715                }],
1716            }],
1717            enums: vec![EnumInfo {
1718                schema_name: "public".to_string(),
1719                name: "task_status".to_string(),
1720                variants: vec!["idle".to_string()],
1721                default_variant: None,
1722            }],
1723            ..Default::default()
1724        };
1725        let defaults = extract_enum_defaults(&schema);
1726        assert!(defaults.is_empty());
1727    }
1728
1729    #[test]
1730    fn test_extract_enum_defaults_non_enum_column_ignored() {
1731        let schema = SchemaInfo {
1732            tables: vec![TableInfo {
1733                schema_name: "public".to_string(),
1734                name: "users".to_string(),
1735                columns: vec![ColumnInfo {
1736                    name: "name".to_string(),
1737                    data_type: "character varying".to_string(),
1738                    udt_name: "varchar".to_string(),
1739                    is_nullable: false,
1740                    is_primary_key: false,
1741                    ordinal_position: 0,
1742                    schema_name: "public".to_string(),
1743                    udt_schema: None,
1744                    column_default: Some("'hello'::character varying".to_string()),
1745                }],
1746            }],
1747            enums: vec![],
1748            ..Default::default()
1749        };
1750        let defaults = extract_enum_defaults(&schema);
1751        assert!(defaults.is_empty());
1752    }
1753
1754    #[test]
1755    fn test_generate_enum_with_default() {
1756        let schema = SchemaInfo {
1757            tables: vec![TableInfo {
1758                schema_name: "public".to_string(),
1759                name: "tasks".to_string(),
1760                columns: vec![ColumnInfo {
1761                    name: "status".to_string(),
1762                    data_type: "USER-DEFINED".to_string(),
1763                    udt_name: "task_status".to_string(),
1764                    is_nullable: false,
1765                    is_primary_key: false,
1766                    ordinal_position: 0,
1767                    schema_name: "public".to_string(),
1768                    udt_schema: None,
1769                    column_default: Some("'idle'::task_status".to_string()),
1770                }],
1771            }],
1772            enums: vec![EnumInfo {
1773                schema_name: "public".to_string(),
1774                name: "task_status".to_string(),
1775                variants: vec!["idle".to_string(), "running".to_string()],
1776                default_variant: None,
1777            }],
1778            ..Default::default()
1779        };
1780        let files = generate(
1781            &schema,
1782            DatabaseKind::Postgres,
1783            &[],
1784            &HashMap::new(),
1785            false,
1786            TimeCrate::Chrono,
1787        )
1788        .unwrap();
1789        let types_file = files.iter().find(|f| f.filename == "types.rs").unwrap();
1790        assert!(types_file.code.contains("impl Default for TaskStatus"));
1791        assert!(types_file.code.contains("Self::Idle"));
1792    }
1793}