Skip to main content

sqlx_gen/codegen/
mod.rs

1pub mod composite_gen;
2pub mod crud_gen;
3pub mod domain_gen;
4pub mod entity_parser;
5pub mod enum_gen;
6pub mod struct_gen;
7
8use std::collections::{BTreeSet, HashMap};
9
10use proc_macro2::TokenStream;
11
12use crate::cli::DatabaseKind;
13use crate::introspect::SchemaInfo;
14
/// Words that Rust reserves and that therefore cannot appear as plain identifiers.
///
/// The list starts with the keywords in active use and ends (from `yield` onwards)
/// with the words reserved for future use.
const RUST_KEYWORDS: &[&str] = &[
    "as", "async", "await", "break", "const", "continue", "crate", "dyn",
    "else", "enum", "extern", "false", "fn", "for", "if", "impl",
    "in", "let", "loop", "match", "mod", "move", "mut", "pub",
    "ref", "return", "self", "Self", "static", "struct", "super", "trait",
    "true", "type", "unsafe", "use", "where", "while",
    "yield", "abstract", "become", "box", "do", "final", "macro", "override",
    "priv", "try", "typeof", "unsized", "virtual",
];

/// Reports whether `name` is one of the reserved words in [`RUST_KEYWORDS`].
///
/// The comparison is exact and case-sensitive: `"type"` is a keyword, `"Type"` is not.
pub fn is_rust_keyword(name: &str) -> bool {
    RUST_KEYWORDS.iter().any(|keyword| *keyword == name)
}
28
/// Builds the `use` lines required by the well-known derives found in `extra_derives`.
///
/// Only the serde derives are recognised. When `Serialize` and/or `Deserialize` is
/// requested, a single grouped `use serde::{...};` line is returned, always listing
/// `Serialize` before `Deserialize`. Any other derive contributes nothing.
pub fn imports_for_derives(extra_derives: &[String]) -> Vec<String> {
    let requested: Vec<&str> = ["Serialize", "Deserialize"]
        .iter()
        .copied()
        .filter(|wanted| extra_derives.iter().any(|d| d.as_str() == *wanted))
        .collect();

    if requested.is_empty() {
        return Vec::new();
    }

    vec![format!("use serde::{{{}}};", requested.join(", "))]
}
45
/// Normalize a table name for use as a Rust module/filename by collapsing every run
/// of consecutive underscores into a single underscore.
///
/// All other characters are copied through untouched; leading and trailing
/// underscores are kept (a run of them is still collapsed to one).
pub fn normalize_module_name(name: &str) -> String {
    name.chars()
        .fold(String::with_capacity(name.len()), |mut out, ch| {
            // An underscore is dropped only when the output already ends in one.
            if ch != '_' || !out.ends_with('_') {
                out.push(ch);
            }
            out
        })
}
64
/// Schema names treated as the database default; these never get a filename prefix.
const DEFAULT_SCHEMAS: &[&str] = &["public", "main", "dbo"];

/// Reports whether `schema` is one of the well-known default schemas
/// (`public`, `main`, `dbo`). The match is exact and case-sensitive.
pub fn is_default_schema(schema: &str) -> bool {
    DEFAULT_SCHEMAS.iter().any(|known| *known == schema)
}
72
73/// Build a module name, prefixing with schema only when the name collides
74/// (same table name exists in multiple schemas).
75pub fn build_module_name(schema_name: &str, table_name: &str, name_collides: bool) -> String {
76    if name_collides && !is_default_schema(schema_name) {
77        normalize_module_name(&format!("{}_{}", schema_name, table_name))
78    } else {
79        normalize_module_name(table_name)
80    }
81}
82
83/// Find table/view names that appear in more than one schema.
84fn find_colliding_names(schema_info: &SchemaInfo) -> BTreeSet<&str> {
85    let mut seen: HashMap<&str, BTreeSet<&str>> = HashMap::new();
86    for t in &schema_info.tables {
87        seen.entry(t.name.as_str()).or_default().insert(t.schema_name.as_str());
88    }
89    for v in &schema_info.views {
90        seen.entry(v.name.as_str()).or_default().insert(v.schema_name.as_str());
91    }
92    seen.into_iter()
93        .filter(|(_, schemas)| schemas.len() > 1)
94        .map(|(name, _)| name)
95        .collect()
96}
97
/// A generated code file: its output filename, an optional origin note, and its source text.
#[derive(Debug, Clone)]
pub struct GeneratedFile {
    /// Name of the output file (e.g. `users.rs`, or `types.rs` for the shared types file).
    pub filename: String,
    /// Optional origin comment (e.g. "Table: schema.name")
    pub origin: Option<String>,
    /// Formatted source text of the file, with any `use` lines already placed at the top.
    pub code: String,
}
106
107/// Generate all code for a given schema.
108pub fn generate(
109    schema_info: &SchemaInfo,
110    db_kind: DatabaseKind,
111    extra_derives: &[String],
112    type_overrides: &HashMap<String, String>,
113    single_file: bool,
114) -> Vec<GeneratedFile> {
115    let mut files = Vec::new();
116
117    // Detect table/view names that appear in multiple schemas (collisions)
118    let colliding_names = find_colliding_names(schema_info);
119
120    // Generate struct files for each table
121    for table in &schema_info.tables {
122        let (tokens, imports) =
123            struct_gen::generate_struct(table, db_kind, schema_info, extra_derives, type_overrides, false);
124        let imports = filter_imports(&imports, single_file);
125        let code = format_tokens_with_imports(&tokens, &imports);
126        let module_name = build_module_name(&table.schema_name, &table.name, colliding_names.contains(table.name.as_str()));
127        let origin = format!("Table: {}.{}", table.schema_name, table.name);
128        files.push(GeneratedFile {
129            filename: format!("{}.rs", module_name),
130            origin: Some(origin),
131            code,
132        });
133    }
134
135    // Generate struct files for each view
136    for view in &schema_info.views {
137        let (tokens, imports) =
138            struct_gen::generate_struct(view, db_kind, schema_info, extra_derives, type_overrides, true);
139        let imports = filter_imports(&imports, single_file);
140        let code = format_tokens_with_imports(&tokens, &imports);
141        let module_name = build_module_name(&view.schema_name, &view.name, colliding_names.contains(view.name.as_str()));
142        let origin = format!("View: {}.{}", view.schema_name, view.name);
143        files.push(GeneratedFile {
144            filename: format!("{}.rs", module_name),
145            origin: Some(origin),
146            code,
147        });
148    }
149
150    // Generate types file (enums, composites, domains)
151    // Each item is formatted individually so we can insert blank lines between them.
152    let mut types_blocks: Vec<String> = Vec::new();
153    let mut types_imports = BTreeSet::new();
154
155    // Enrich enums with default variants extracted from column defaults
156    let enum_defaults = extract_enum_defaults(schema_info);
157    for enum_info in &schema_info.enums {
158        let mut enriched = enum_info.clone();
159        if enriched.default_variant.is_none() {
160            if let Some(default) = enum_defaults.get(&enum_info.name) {
161                enriched.default_variant = Some(default.clone());
162            }
163        }
164        let (tokens, imports) = enum_gen::generate_enum(&enriched, db_kind, extra_derives);
165        types_blocks.push(format_tokens(&tokens));
166        types_imports.extend(imports);
167    }
168
169    for composite in &schema_info.composite_types {
170        let (tokens, imports) = composite_gen::generate_composite(
171            composite,
172            db_kind,
173            schema_info,
174            extra_derives,
175            type_overrides,
176        );
177        types_blocks.push(format_tokens(&tokens));
178        types_imports.extend(imports);
179    }
180
181    for domain in &schema_info.domains {
182        let (tokens, imports) =
183            domain_gen::generate_domain(domain, db_kind, schema_info, type_overrides);
184        types_blocks.push(format_tokens(&tokens));
185        types_imports.extend(imports);
186    }
187
188    if !types_blocks.is_empty() {
189        let import_lines: String = types_imports
190            .iter()
191            .map(|i| format!("{}\n", i))
192            .collect();
193        let body = types_blocks.join("\n");
194        let code = if import_lines.is_empty() {
195            body
196        } else {
197            format!("{}\n\n{}", import_lines.trim_end(), body)
198        };
199        files.push(GeneratedFile {
200            filename: "types.rs".to_string(),
201            origin: None,
202            code,
203        });
204    }
205
206    files
207}
208
209/// Extract default variant values for enums by scanning column defaults across all tables and views.
210/// PostgreSQL column defaults look like `'idle'::task_status` or `'active'::public.task_status`.
211fn extract_enum_defaults(schema_info: &SchemaInfo) -> HashMap<String, String> {
212    let mut defaults: HashMap<String, String> = HashMap::new();
213
214    let all_columns = schema_info
215        .tables
216        .iter()
217        .chain(schema_info.views.iter())
218        .flat_map(|t| t.columns.iter());
219
220    for col in all_columns {
221        let default_expr = match &col.column_default {
222            Some(d) => d,
223            None => continue,
224        };
225
226        // Strip leading underscore for array types to get the base enum name
227        let base_udt = col.udt_name.strip_prefix('_').unwrap_or(&col.udt_name);
228
229        // Check if this column references a known enum
230        let enum_match = schema_info.enums.iter().find(|e| e.name == base_udt);
231        if enum_match.is_none() {
232            continue;
233        }
234
235        // Parse PG default: 'variant'::type_name
236        if let Some(variant) = parse_pg_enum_default(default_expr) {
237            defaults.entry(base_udt.to_string()).or_insert(variant);
238        }
239    }
240
241    defaults
242}
243
/// Pull the quoted label out of a PostgreSQL default expression of the form
/// `'value'::some_type` (e.g. `'idle'::task_status`, `'idle'::public.task_status`).
///
/// Surrounding whitespace is ignored. Returns `None` unless the expression starts with
/// a single-quoted literal whose closing quote is immediately followed by a `::` cast.
fn parse_pg_enum_default(default_expr: &str) -> Option<String> {
    let after_open_quote = default_expr.trim().strip_prefix('\'')?;
    // Split at the first closing quote: label on the left, the cast on the right.
    let (label, tail) = after_open_quote.split_once('\'')?;
    if tail.starts_with("::") {
        Some(label.to_string())
    } else {
        None
    }
}
261
/// Return the imports to emit for a table/view file.
///
/// In single-file mode every import mentioning `super::types::` is dropped, since the
/// types live in the same file. In multi-file mode the set is returned as an
/// unmodified copy.
fn filter_imports(imports: &BTreeSet<String>, single_file: bool) -> BTreeSet<String> {
    imports
        .iter()
        .filter(|line| !(single_file && line.contains("super::types::")))
        .cloned()
        .collect()
}
274
275/// Parse and format a TokenStream via prettyplease, then post-process spacing.
276pub(crate) fn parse_and_format(tokens: &TokenStream) -> String {
277    let file = syn::parse2::<syn::File>(tokens.clone()).unwrap_or_else(|e| {
278        log::error!("Failed to parse generated code: {}", e);
279        log::error!("This is a bug in sqlx-gen. Raw tokens:\n  {}", tokens);
280        std::process::exit(1);
281    });
282    let raw = prettyplease::unparse(&file);
283    add_blank_lines_between_items(&raw)
284}
285
/// Format a single TokenStream block (no imports).
///
/// Thin wrapper: delegates to `parse_and_format` and returns its output unchanged.
pub(crate) fn format_tokens(tokens: &TokenStream) -> String {
    parse_and_format(tokens)
}
290
291pub fn format_tokens_with_imports(tokens: &TokenStream, imports: &BTreeSet<String>) -> String {
292    let formatted = parse_and_format(tokens);
293
294    let used_imports: Vec<&String> = imports
295        .iter()
296        .filter(|imp| is_import_used(imp, &formatted))
297        .collect();
298
299    if used_imports.is_empty() {
300        formatted
301    } else {
302        let import_lines: String = used_imports
303            .iter()
304            .map(|i| format!("{}\n", i))
305            .collect();
306        format!("{}\n\n{}", import_lines.trim_end(), formatted)
307    }
308}
309
310/// Check if an import is actually used in the generated code.
311/// Extracts the imported type names and checks if they appear in the code.
312fn is_import_used(import: &str, code: &str) -> bool {
313    // "use foo::bar::Baz;" → check for "Baz"
314    // "use foo::{A, B};" → check for "A" or "B"
315    // "use foo::bar::*;" → always keep
316    let trimmed = import.trim().trim_end_matches(';');
317    let path = trimmed.strip_prefix("use ").unwrap_or(trimmed);
318
319    if path.ends_with("::*") {
320        return true;
321    }
322
323    // Handle grouped imports: use foo::{A, B, C};
324    if let Some(start) = path.find('{') {
325        if let Some(end) = path.find('}') {
326            let names = &path[start + 1..end];
327            return names
328                .split(',')
329                .map(|n| n.trim())
330                .filter(|n| !n.is_empty())
331                .any(|name| code.contains(name));
332        }
333    }
334
335    // Simple import: use foo::Bar;
336    if let Some(name) = path.rsplit("::").next() {
337        return code.contains(name);
338    }
339
340    true
341}
342
/// Post-process formatted code to:
/// - Add blank lines between enum variants with `#[sqlx(rename`
/// - Add blank lines between top-level items (structs, impls)
/// - Add blank lines between logical blocks inside async methods
///
/// At most one blank line is inserted in front of any line, even when several of the
/// spacing rules apply to it at once. Lines are re-joined with `\n`, so a trailing
/// newline in `code` is not preserved.
fn add_blank_lines_between_items(code: &str) -> String {
    let mut result: Vec<&str> = Vec::new();
    let mut prev_trimmed: Option<&str> = None;

    for line in code.lines() {
        let trimmed = line.trim();
        if let Some(prev) = prev_trimmed {
            if needs_blank_line_between(prev, trimmed) {
                result.push("");
            }
        }
        result.push(line);
        prev_trimmed = Some(trimmed);
    }

    result.join("\n")
}

/// Decides whether a blank line belongs between two consecutive lines.
/// Both arguments are expected to be already trimmed.
fn needs_blank_line_between(prev: &str, current: &str) -> bool {
    // `#[sqlx(rename` following a variant line (ending with `,`); the first variant
    // of an enum follows `{` and is therefore left alone.
    let renamed_variant = current.starts_with("#[sqlx(rename") && prev.ends_with(',');

    // Top-level items and methods that follow a closing brace.
    let item_starts = ["pub struct", "impl ", "#[derive", "pub async fn", "pub fn"];
    let new_item = prev == "}" && item_starts.iter().any(|start| current.starts_with(*start));

    // `let` or `Ok(` after a line ending in `.await?;` / `.await?` or after a
    // `.unwrap_or(…);` statement.
    let prev_ends_step = prev.ends_with(".await?;")
        || prev.ends_with(".await?")
        || (prev.ends_with(';') && prev.contains(".unwrap_or("));
    let after_step =
        prev_ends_step && (current.starts_with("let ") || current.starts_with("Ok("));

    // A sqlx query `let` that follows a plain (non-sqlx) `let` assignment.
    let query_after_setup = current.starts_with("let ")
        && current.contains("sqlx::")
        && prev.starts_with("let ")
        && !prev.contains("sqlx::");

    renamed_variant || new_item || after_step || query_after_setup
}
404
405#[cfg(test)]
406mod tests {
407    use super::*;
408    use crate::introspect::{
409        ColumnInfo, CompositeTypeInfo, DomainInfo, EnumInfo, SchemaInfo, TableInfo,
410    };
411    use std::collections::HashMap;
412
413    // ========== is_rust_keyword ==========
414
415    #[test]
416    fn test_keyword_type() {
417        assert!(is_rust_keyword("type"));
418    }
419
420    #[test]
421    fn test_keyword_fn() {
422        assert!(is_rust_keyword("fn"));
423    }
424
425    #[test]
426    fn test_keyword_let() {
427        assert!(is_rust_keyword("let"));
428    }
429
430    #[test]
431    fn test_keyword_match() {
432        assert!(is_rust_keyword("match"));
433    }
434
435    #[test]
436    fn test_keyword_async() {
437        assert!(is_rust_keyword("async"));
438    }
439
440    #[test]
441    fn test_keyword_await() {
442        assert!(is_rust_keyword("await"));
443    }
444
445    #[test]
446    fn test_keyword_yield() {
447        assert!(is_rust_keyword("yield"));
448    }
449
450    #[test]
451    fn test_keyword_abstract() {
452        assert!(is_rust_keyword("abstract"));
453    }
454
455    #[test]
456    fn test_keyword_try() {
457        assert!(is_rust_keyword("try"));
458    }
459
460    #[test]
461    fn test_not_keyword_name() {
462        assert!(!is_rust_keyword("name"));
463    }
464
465    #[test]
466    fn test_not_keyword_id() {
467        assert!(!is_rust_keyword("id"));
468    }
469
470    #[test]
471    fn test_not_keyword_uppercase_type() {
472        assert!(!is_rust_keyword("Type"));
473    }
474
475    // ========== normalize_module_name ==========
476
477    #[test]
478    fn test_normalize_no_underscores() {
479        assert_eq!(normalize_module_name("users"), "users");
480    }
481
482    #[test]
483    fn test_normalize_single_underscore() {
484        assert_eq!(normalize_module_name("user_roles"), "user_roles");
485    }
486
487    #[test]
488    fn test_normalize_double_underscore() {
489        assert_eq!(normalize_module_name("user__roles"), "user_roles");
490    }
491
492    #[test]
493    fn test_normalize_triple_underscore() {
494        assert_eq!(normalize_module_name("a___b"), "a_b");
495    }
496
497    #[test]
498    fn test_normalize_leading_underscore() {
499        assert_eq!(normalize_module_name("_private"), "_private");
500    }
501
502    #[test]
503    fn test_normalize_trailing_underscore() {
504        assert_eq!(normalize_module_name("name_"), "name_");
505    }
506
507    #[test]
508    fn test_normalize_double_leading() {
509        assert_eq!(normalize_module_name("__double_leading"), "_double_leading");
510    }
511
512    #[test]
513    fn test_normalize_multiple_groups() {
514        assert_eq!(normalize_module_name("a__b__c"), "a_b_c");
515    }
516
517    // ========== build_module_name ==========
518
519    #[test]
520    fn test_build_no_collision_no_prefix() {
521        assert_eq!(build_module_name("public", "users", false), "users");
522    }
523
524    #[test]
525    fn test_build_no_collision_non_default_no_prefix() {
526        assert_eq!(build_module_name("billing", "invoices", false), "invoices");
527    }
528
529    #[test]
530    fn test_build_collision_prefixed() {
531        assert_eq!(build_module_name("billing", "users", true), "billing_users");
532    }
533
534    #[test]
535    fn test_build_collision_default_schema_no_prefix() {
536        assert_eq!(build_module_name("public", "users", true), "users");
537    }
538
539    #[test]
540    fn test_build_collision_normalizes_double_underscore() {
541        assert_eq!(build_module_name("billing", "agent__connector", true), "billing_agent_connector");
542    }
543
544    // ========== is_default_schema ==========
545
546    #[test]
547    fn test_default_schema_public() {
548        assert!(is_default_schema("public"));
549    }
550
551    #[test]
552    fn test_default_schema_main() {
553        assert!(is_default_schema("main"));
554    }
555
556    #[test]
557    fn test_non_default_schema() {
558        assert!(!is_default_schema("billing"));
559    }
560
561    // ========== imports_for_derives ==========
562
563    #[test]
564    fn test_imports_empty() {
565        let result = imports_for_derives(&[]);
566        assert!(result.is_empty());
567    }
568
569    #[test]
570    fn test_imports_serialize_only() {
571        let derives = vec!["Serialize".to_string()];
572        let result = imports_for_derives(&derives);
573        assert_eq!(result, vec!["use serde::{Serialize};"]);
574    }
575
576    #[test]
577    fn test_imports_deserialize_only() {
578        let derives = vec!["Deserialize".to_string()];
579        let result = imports_for_derives(&derives);
580        assert_eq!(result, vec!["use serde::{Deserialize};"]);
581    }
582
583    #[test]
584    fn test_imports_both_serde() {
585        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
586        let result = imports_for_derives(&derives);
587        assert_eq!(result, vec!["use serde::{Serialize, Deserialize};"]);
588    }
589
590    #[test]
591    fn test_imports_non_serde() {
592        let derives = vec!["Hash".to_string()];
593        let result = imports_for_derives(&derives);
594        assert!(result.is_empty());
595    }
596
597    #[test]
598    fn test_imports_non_serde_multiple() {
599        let derives = vec!["PartialEq".to_string(), "Eq".to_string()];
600        let result = imports_for_derives(&derives);
601        assert!(result.is_empty());
602    }
603
604    #[test]
605    fn test_imports_mixed_serde_and_others() {
606        let derives = vec![
607            "Serialize".to_string(),
608            "Hash".to_string(),
609            "Deserialize".to_string(),
610        ];
611        let result = imports_for_derives(&derives);
612        assert_eq!(result.len(), 1);
613        assert!(result[0].contains("Serialize"));
614        assert!(result[0].contains("Deserialize"));
615    }
616
617    // ========== add_blank_lines_between_items ==========
618
619    #[test]
620    fn test_blank_lines_between_renamed_variants() {
621        let input = "pub enum Foo {\n    #[sqlx(rename = \"a\")]\n    A,\n    #[sqlx(rename = \"b\")]\n    B,\n}";
622        let result = add_blank_lines_between_items(input);
623        assert!(result.contains("A,\n\n    #[sqlx(rename = \"b\")]"));
624    }
625
626    #[test]
627    fn test_no_blank_line_for_first_variant() {
628        let input = "pub enum Foo {\n    #[sqlx(rename = \"a\")]\n    A,\n}";
629        let result = add_blank_lines_between_items(input);
630        // No blank line before first #[sqlx(rename because previous line is `{`
631        assert!(!result.contains("{\n\n"));
632    }
633
634    #[test]
635    fn test_no_change_without_rename() {
636        let input = "pub enum Foo {\n    A,\n    B,\n}";
637        let result = add_blank_lines_between_items(input);
638        assert_eq!(result, input);
639    }
640
641    #[test]
642    fn test_no_change_for_struct() {
643        let input = "pub struct Foo {\n    pub a: i32,\n    pub b: String,\n}";
644        let result = add_blank_lines_between_items(input);
645        assert_eq!(result, input);
646    }
647
648    // ========== filter_imports ==========
649
650    #[test]
651    fn test_filter_single_file_strips_super_types() {
652        let mut imports = BTreeSet::new();
653        imports.insert("use super::types::Foo;".to_string());
654        imports.insert("use chrono::NaiveDateTime;".to_string());
655        let result = filter_imports(&imports, true);
656        assert!(!result.contains("use super::types::Foo;"));
657        assert!(result.contains("use chrono::NaiveDateTime;"));
658    }
659
660    #[test]
661    fn test_filter_single_file_keeps_other_imports() {
662        let mut imports = BTreeSet::new();
663        imports.insert("use chrono::NaiveDateTime;".to_string());
664        let result = filter_imports(&imports, true);
665        assert!(result.contains("use chrono::NaiveDateTime;"));
666    }
667
668    #[test]
669    fn test_filter_multi_file_keeps_all() {
670        let mut imports = BTreeSet::new();
671        imports.insert("use super::types::Foo;".to_string());
672        imports.insert("use chrono::NaiveDateTime;".to_string());
673        let result = filter_imports(&imports, false);
674        assert_eq!(result.len(), 2);
675    }
676
677    #[test]
678    fn test_filter_empty_set() {
679        let imports = BTreeSet::new();
680        let result = filter_imports(&imports, true);
681        assert!(result.is_empty());
682    }
683
684    // ========== generate() orchestrator ==========
685
686    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
687        TableInfo {
688            schema_name: "public".to_string(),
689            name: name.to_string(),
690            columns,
691        }
692    }
693
694    fn make_col(name: &str, udt_name: &str) -> ColumnInfo {
695        ColumnInfo {
696            name: name.to_string(),
697            data_type: udt_name.to_string(),
698            udt_name: udt_name.to_string(),
699            is_nullable: false,
700            is_primary_key: false,
701            ordinal_position: 0,
702            schema_name: "public".to_string(),
703            column_default: None,
704        }
705    }
706
707    #[test]
708    fn test_generate_empty_schema() {
709        let schema = SchemaInfo::default();
710        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
711        assert!(files.is_empty());
712    }
713
714    #[test]
715    fn test_generate_one_table() {
716        let schema = SchemaInfo {
717            tables: vec![make_table("users", vec![make_col("id", "int4")])],
718            ..Default::default()
719        };
720        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
721        assert_eq!(files.len(), 1);
722        assert_eq!(files[0].filename, "users.rs");
723    }
724
725    #[test]
726    fn test_generate_two_tables() {
727        let schema = SchemaInfo {
728            tables: vec![
729                make_table("users", vec![make_col("id", "int4")]),
730                make_table("posts", vec![make_col("id", "int4")]),
731            ],
732            ..Default::default()
733        };
734        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
735        assert_eq!(files.len(), 2);
736    }
737
738    #[test]
739    fn test_generate_enum_creates_types_file() {
740        let schema = SchemaInfo {
741            enums: vec![EnumInfo {
742                schema_name: "public".to_string(),
743                name: "status".to_string(),
744                variants: vec!["active".to_string(), "inactive".to_string()],
745                default_variant: None,
746            }],
747            ..Default::default()
748        };
749        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
750        assert_eq!(files.len(), 1);
751        assert_eq!(files[0].filename, "types.rs");
752    }
753
754    #[test]
755    fn test_generate_enums_composites_domains_single_types_file() {
756        let schema = SchemaInfo {
757            enums: vec![EnumInfo {
758                schema_name: "public".to_string(),
759                name: "status".to_string(),
760                variants: vec!["active".to_string()],
761                default_variant: None,
762            }],
763            composite_types: vec![CompositeTypeInfo {
764                schema_name: "public".to_string(),
765                name: "address".to_string(),
766                fields: vec![make_col("street", "text")],
767            }],
768            domains: vec![DomainInfo {
769                schema_name: "public".to_string(),
770                name: "email".to_string(),
771                base_type: "text".to_string(),
772            }],
773            ..Default::default()
774        };
775        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
776        // Should produce exactly 1 types.rs
777        let types_files: Vec<_> = files.iter().filter(|f| f.filename == "types.rs").collect();
778        assert_eq!(types_files.len(), 1);
779    }
780
781    #[test]
782    fn test_generate_tables_and_enums() {
783        let schema = SchemaInfo {
784            tables: vec![make_table("users", vec![make_col("id", "int4")])],
785            enums: vec![EnumInfo {
786                schema_name: "public".to_string(),
787                name: "status".to_string(),
788                variants: vec!["active".to_string()],
789                default_variant: None,
790            }],
791            ..Default::default()
792        };
793        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
794        assert_eq!(files.len(), 2); // users.rs + types.rs
795    }
796
797    #[test]
798    fn test_generate_filename_normalized() {
799        let schema = SchemaInfo {
800            tables: vec![make_table("user__data", vec![make_col("id", "int4")])],
801            ..Default::default()
802        };
803        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
804        assert_eq!(files[0].filename, "user_data.rs");
805    }
806
807    #[test]
808    fn test_generate_origin_correct() {
809        let schema = SchemaInfo {
810            tables: vec![make_table("users", vec![make_col("id", "int4")])],
811            ..Default::default()
812        };
813        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
814        assert_eq!(files[0].origin, Some("Table: public.users".to_string()));
815    }
816
817    #[test]
818    fn test_generate_types_no_origin() {
819        let schema = SchemaInfo {
820            enums: vec![EnumInfo {
821                schema_name: "public".to_string(),
822                name: "status".to_string(),
823                variants: vec!["active".to_string()],
824                default_variant: None,
825            }],
826            ..Default::default()
827        };
828        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
829        assert_eq!(files[0].origin, None);
830    }
831
832    #[test]
833    fn test_generate_single_file_filters_super_types_imports() {
834        let schema = SchemaInfo {
835            tables: vec![make_table("users", vec![make_col("id", "int4")])],
836            enums: vec![EnumInfo {
837                schema_name: "public".to_string(),
838                name: "status".to_string(),
839                variants: vec!["active".to_string()],
840                default_variant: None,
841            }],
842            ..Default::default()
843        };
844        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), true);
845        // struct file should not have super::types:: imports
846        let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
847        assert!(!struct_file.code.contains("super::types::"));
848    }
849
850    #[test]
851    fn test_generate_multi_file_keeps_super_types_imports() {
852        // Table with a column referencing an enum
853        let schema = SchemaInfo {
854            tables: vec![make_table("users", vec![make_col("status", "status")])],
855            enums: vec![EnumInfo {
856                schema_name: "public".to_string(),
857                name: "status".to_string(),
858                variants: vec!["active".to_string()],
859                default_variant: None,
860            }],
861            ..Default::default()
862        };
863        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
864        let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
865        assert!(struct_file.code.contains("super::types::"));
866    }
867
868    #[test]
869    fn test_generate_extra_derives_in_struct() {
870        let schema = SchemaInfo {
871            tables: vec![make_table("users", vec![make_col("id", "int4")])],
872            ..Default::default()
873        };
874        let derives = vec!["Serialize".to_string()];
875        let files = generate(&schema, DatabaseKind::Postgres, &derives, &HashMap::new(), false);
876        assert!(files[0].code.contains("Serialize"));
877    }
878
879    #[test]
880    fn test_generate_extra_derives_in_enum() {
881        let schema = SchemaInfo {
882            enums: vec![EnumInfo {
883                schema_name: "public".to_string(),
884                name: "status".to_string(),
885                variants: vec!["active".to_string()],
886                default_variant: None,
887            }],
888            ..Default::default()
889        };
890        let derives = vec!["Serialize".to_string()];
891        let files = generate(&schema, DatabaseKind::Postgres, &derives, &HashMap::new(), false);
892        assert!(files[0].code.contains("Serialize"));
893    }
894
895    #[test]
896    fn test_generate_type_overrides_in_struct() {
897        let mut overrides = HashMap::new();
898        overrides.insert("jsonb".to_string(), "MyJson".to_string());
899        let schema = SchemaInfo {
900            tables: vec![make_table("users", vec![make_col("data", "jsonb")])],
901            ..Default::default()
902        };
903        let files = generate(&schema, DatabaseKind::Postgres, &[], &overrides, false);
904        assert!(files[0].code.contains("MyJson"));
905    }
906
907    #[test]
908    fn test_generate_valid_rust_syntax() {
909        let schema = SchemaInfo {
910            tables: vec![make_table("users", vec![
911                make_col("id", "int4"),
912                make_col("name", "text"),
913            ])],
914            enums: vec![EnumInfo {
915                schema_name: "public".to_string(),
916                name: "status".to_string(),
917                variants: vec!["active".to_string(), "inactive".to_string()],
918                default_variant: None,
919            }],
920            ..Default::default()
921        };
922        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
923        for f in &files {
924            // Should be parseable as valid Rust
925            let parse_result = syn::parse_file(&f.code);
926            assert!(parse_result.is_ok(), "Failed to parse {}: {:?}", f.filename, parse_result.err());
927        }
928    }
929
930    // ========== generate() — views ==========
931
932    fn make_view(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
933        TableInfo {
934            schema_name: "public".to_string(),
935            name: name.to_string(),
936            columns,
937        }
938    }
939
940    #[test]
941    fn test_generate_one_view() {
942        let schema = SchemaInfo {
943            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
944            ..Default::default()
945        };
946        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
947        assert_eq!(files.len(), 1);
948        assert_eq!(files[0].filename, "active_users.rs");
949    }
950
951    #[test]
952    fn test_generate_view_origin() {
953        let schema = SchemaInfo {
954            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
955            ..Default::default()
956        };
957        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
958        assert_eq!(files[0].origin, Some("View: public.active_users".to_string()));
959    }
960
961    #[test]
962    fn test_generate_tables_and_views() {
963        let schema = SchemaInfo {
964            tables: vec![make_table("users", vec![make_col("id", "int4")])],
965            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
966            ..Default::default()
967        };
968        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
969        assert_eq!(files.len(), 2);
970    }
971
972    #[test]
973    fn test_generate_view_valid_rust() {
974        let schema = SchemaInfo {
975            views: vec![make_view("active_users", vec![
976                make_col("id", "int4"),
977                make_col("name", "text"),
978            ])],
979            ..Default::default()
980        };
981        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
982        let parse_result = syn::parse_file(&files[0].code);
983        assert!(parse_result.is_ok(), "Failed to parse: {:?}", parse_result.err());
984    }
985
986    #[test]
987    fn test_generate_view_nullable_column() {
988        let schema = SchemaInfo {
989            views: vec![make_view("v", vec![ColumnInfo {
990                name: "email".to_string(),
991                data_type: "text".to_string(),
992                udt_name: "text".to_string(),
993                is_nullable: true,
994                is_primary_key: false,
995                ordinal_position: 0,
996                schema_name: "public".to_string(),
997                column_default: None,
998            }])],
999            ..Default::default()
1000        };
1001        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
1002        assert!(files[0].code.contains("Option<String>"));
1003    }
1004
1005    #[test]
1006    fn test_generate_collision_both_prefixed() {
1007        let schema = SchemaInfo {
1008            tables: vec![
1009                make_table("users", vec![make_col("id", "int4")]),
1010                TableInfo {
1011                    schema_name: "billing".to_string(),
1012                    name: "users".to_string(),
1013                    columns: vec![make_col("id", "int4")],
1014                },
1015            ],
1016            ..Default::default()
1017        };
1018        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
1019        let filenames: Vec<_> = files.iter().map(|f| f.filename.as_str()).collect();
1020        assert!(filenames.contains(&"users.rs"));
1021        assert!(filenames.contains(&"billing_users.rs"));
1022    }
1023
1024    #[test]
1025    fn test_generate_no_collision_no_prefix() {
1026        let schema = SchemaInfo {
1027            tables: vec![
1028                make_table("users", vec![make_col("id", "int4")]),
1029                TableInfo {
1030                    schema_name: "billing".to_string(),
1031                    name: "invoices".to_string(),
1032                    columns: vec![make_col("id", "int4")],
1033                },
1034            ],
1035            ..Default::default()
1036        };
1037        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
1038        let filenames: Vec<_> = files.iter().map(|f| f.filename.as_str()).collect();
1039        assert!(filenames.contains(&"users.rs"));
1040        assert!(filenames.contains(&"invoices.rs"));
1041    }
1042
1043    #[test]
1044    fn test_generate_single_schema_no_prefix() {
1045        let schema = SchemaInfo {
1046            tables: vec![
1047                make_table("users", vec![make_col("id", "int4")]),
1048                make_table("posts", vec![make_col("id", "int4")]),
1049            ],
1050            ..Default::default()
1051        };
1052        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
1053        assert_eq!(files[0].filename, "users.rs");
1054        assert_eq!(files[1].filename, "posts.rs");
1055    }
1056
1057    #[test]
1058    fn test_generate_view_single_file_mode() {
1059        let schema = SchemaInfo {
1060            tables: vec![make_table("users", vec![make_col("id", "int4")])],
1061            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
1062            ..Default::default()
1063        };
1064        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), true);
1065        assert_eq!(files.len(), 2);
1066    }
1067
1068    // ========== parse_pg_enum_default ==========
1069
1070    #[test]
1071    fn test_parse_pg_enum_default_simple() {
1072        assert_eq!(
1073            parse_pg_enum_default("'idle'::task_status"),
1074            Some("idle".to_string())
1075        );
1076    }
1077
1078    #[test]
1079    fn test_parse_pg_enum_default_schema_qualified() {
1080        assert_eq!(
1081            parse_pg_enum_default("'active'::public.task_status"),
1082            Some("active".to_string())
1083        );
1084    }
1085
1086    #[test]
1087    fn test_parse_pg_enum_default_not_enum() {
1088        // No single-quote pattern
1089        assert_eq!(parse_pg_enum_default("nextval('users_id_seq')"), None);
1090    }
1091
1092    #[test]
1093    fn test_parse_pg_enum_default_no_cast() {
1094        assert_eq!(parse_pg_enum_default("'hello'"), None);
1095    }
1096
1097    #[test]
1098    fn test_parse_pg_enum_default_empty() {
1099        assert_eq!(parse_pg_enum_default(""), None);
1100    }
1101
1102    // ========== extract_enum_defaults ==========
1103
1104    #[test]
1105    fn test_extract_enum_defaults_from_column() {
1106        let schema = SchemaInfo {
1107            tables: vec![TableInfo {
1108                schema_name: "public".to_string(),
1109                name: "tasks".to_string(),
1110                columns: vec![ColumnInfo {
1111                    name: "status".to_string(),
1112                    data_type: "USER-DEFINED".to_string(),
1113                    udt_name: "task_status".to_string(),
1114                    is_nullable: false,
1115                    is_primary_key: false,
1116                    ordinal_position: 0,
1117                    schema_name: "public".to_string(),
1118                    column_default: Some("'idle'::task_status".to_string()),
1119                }],
1120            }],
1121            enums: vec![EnumInfo {
1122                schema_name: "public".to_string(),
1123                name: "task_status".to_string(),
1124                variants: vec!["idle".to_string(), "running".to_string()],
1125                default_variant: None,
1126            }],
1127            ..Default::default()
1128        };
1129        let defaults = extract_enum_defaults(&schema);
1130        assert_eq!(defaults.get("task_status"), Some(&"idle".to_string()));
1131    }
1132
1133    #[test]
1134    fn test_extract_enum_defaults_no_default() {
1135        let schema = SchemaInfo {
1136            tables: vec![TableInfo {
1137                schema_name: "public".to_string(),
1138                name: "tasks".to_string(),
1139                columns: vec![ColumnInfo {
1140                    name: "status".to_string(),
1141                    data_type: "USER-DEFINED".to_string(),
1142                    udt_name: "task_status".to_string(),
1143                    is_nullable: false,
1144                    is_primary_key: false,
1145                    ordinal_position: 0,
1146                    schema_name: "public".to_string(),
1147                    column_default: None,
1148                }],
1149            }],
1150            enums: vec![EnumInfo {
1151                schema_name: "public".to_string(),
1152                name: "task_status".to_string(),
1153                variants: vec!["idle".to_string()],
1154                default_variant: None,
1155            }],
1156            ..Default::default()
1157        };
1158        let defaults = extract_enum_defaults(&schema);
1159        assert!(defaults.is_empty());
1160    }
1161
1162    #[test]
1163    fn test_extract_enum_defaults_non_enum_column_ignored() {
1164        let schema = SchemaInfo {
1165            tables: vec![TableInfo {
1166                schema_name: "public".to_string(),
1167                name: "users".to_string(),
1168                columns: vec![ColumnInfo {
1169                    name: "name".to_string(),
1170                    data_type: "character varying".to_string(),
1171                    udt_name: "varchar".to_string(),
1172                    is_nullable: false,
1173                    is_primary_key: false,
1174                    ordinal_position: 0,
1175                    schema_name: "public".to_string(),
1176                    column_default: Some("'hello'::character varying".to_string()),
1177                }],
1178            }],
1179            enums: vec![],
1180            ..Default::default()
1181        };
1182        let defaults = extract_enum_defaults(&schema);
1183        assert!(defaults.is_empty());
1184    }
1185
1186    #[test]
1187    fn test_generate_enum_with_default() {
1188        let schema = SchemaInfo {
1189            tables: vec![TableInfo {
1190                schema_name: "public".to_string(),
1191                name: "tasks".to_string(),
1192                columns: vec![ColumnInfo {
1193                    name: "status".to_string(),
1194                    data_type: "USER-DEFINED".to_string(),
1195                    udt_name: "task_status".to_string(),
1196                    is_nullable: false,
1197                    is_primary_key: false,
1198                    ordinal_position: 0,
1199                    schema_name: "public".to_string(),
1200                    column_default: Some("'idle'::task_status".to_string()),
1201                }],
1202            }],
1203            enums: vec![EnumInfo {
1204                schema_name: "public".to_string(),
1205                name: "task_status".to_string(),
1206                variants: vec!["idle".to_string(), "running".to_string()],
1207                default_variant: None,
1208            }],
1209            ..Default::default()
1210        };
1211        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
1212        let types_file = files.iter().find(|f| f.filename == "types.rs").unwrap();
1213        assert!(types_file.code.contains("impl Default for TaskStatus"));
1214        assert!(types_file.code.contains("Self::Idle"));
1215    }
1216}