// sqlx_gen/codegen/mod.rs

1pub mod composite_gen;
2pub mod crud_gen;
3pub mod domain_gen;
4pub mod entity_parser;
5pub mod enum_gen;
6pub mod struct_gen;
7
8use std::collections::{BTreeSet, HashMap};
9
10use proc_macro2::TokenStream;
11
12use crate::cli::DatabaseKind;
13use crate::introspect::SchemaInfo;
14
/// Words the Rust grammar reserves, so they cannot be used as plain identifiers.
const RUST_KEYWORDS: &[&str] = &[
    // Keywords in active use by the language.
    "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum",
    "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move",
    "mut", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true",
    "type", "unsafe", "use", "where", "while",
    // Keywords reserved for future use.
    "yield", "abstract", "become", "box", "do", "final", "macro", "override", "priv", "try",
    "typeof", "unsized", "virtual",
];

/// Reports whether `name` is exactly one of the entries in `RUST_KEYWORDS`.
///
/// The comparison is case-sensitive: `"type"` and `"Self"` match, `"Type"` does not.
pub fn is_rust_keyword(name: &str) -> bool {
    RUST_KEYWORDS.iter().any(|keyword| *keyword == name)
}
28
/// Builds the `use` lines required by the well-known derives found in `extra_derives`.
///
/// Only the serde derives are recognised. When `Serialize` and/or `Deserialize` is requested,
/// a single grouped `use serde::{…};` line is returned (always listing `Serialize` first);
/// any other derive contributes nothing, so the result may be empty.
pub fn imports_for_derives(extra_derives: &[String]) -> Vec<String> {
    // Fixed order here determines the order inside the braces of the emitted import.
    let serde_traits: Vec<&str> = ["Serialize", "Deserialize"]
        .iter()
        .copied()
        .filter(|&wanted| extra_derives.iter().any(|derive| derive == wanted))
        .collect();

    if serde_traits.is_empty() {
        return Vec::new();
    }
    vec![format!("use serde::{{{}}};", serde_traits.join(", "))]
}
45
/// Turns a table name into a Rust module/file name by collapsing every run of
/// consecutive underscores into a single underscore.
///
/// Leading and trailing underscores are kept (also collapsed); all other characters
/// are copied through unchanged.
pub fn normalize_module_name(name: &str) -> String {
    let mut normalized = String::with_capacity(name.len());
    for ch in name.chars() {
        // An underscore directly after one we already emitted is part of the same run.
        if ch == '_' && normalized.ends_with('_') {
            continue;
        }
        normalized.push(ch);
    }
    normalized
}
64
/// Schema names treated as well-known defaults; they don't need a prefix in filenames.
const DEFAULT_SCHEMAS: &[&str] = &["public", "main", "dbo"];

/// Reports whether `schema` is exactly one of the well-known defaults
/// (`public`, `main`, `dbo`). The comparison is case-sensitive.
pub fn is_default_schema(schema: &str) -> bool {
    DEFAULT_SCHEMAS.iter().any(|known| *known == schema)
}
72
73/// Build a module name, prefixing with schema when there are multiple schemas
74/// and the schema is not a well-known default.
75pub fn build_module_name(schema_name: &str, table_name: &str, has_multiple_schemas: bool) -> String {
76    if !has_multiple_schemas || DEFAULT_SCHEMAS.contains(&schema_name) {
77        normalize_module_name(table_name)
78    } else {
79        normalize_module_name(&format!("{}_{}", schema_name, table_name))
80    }
81}
82
/// A generated code file: its name, an optional origin comment, and its source text.
///
/// Any `use` lines the file needs are already part of `code`; there is no separate
/// imports field.
#[derive(Debug, Clone)]
pub struct GeneratedFile {
    /// File name including the `.rs` extension (e.g. `users.rs`, `types.rs`).
    pub filename: String,
    /// Optional origin comment (e.g. "Table: schema.name")
    pub origin: Option<String>,
    /// Formatted Rust source for this file.
    pub code: String,
}
91
92/// Generate all code for a given schema.
93pub fn generate(
94    schema_info: &SchemaInfo,
95    db_kind: DatabaseKind,
96    extra_derives: &[String],
97    type_overrides: &HashMap<String, String>,
98    single_file: bool,
99) -> Vec<GeneratedFile> {
100    let mut files = Vec::new();
101
102    // Detect if multiple schemas are present
103    let mut schemas = BTreeSet::new();
104    for t in &schema_info.tables {
105        schemas.insert(t.schema_name.as_str());
106    }
107    for v in &schema_info.views {
108        schemas.insert(v.schema_name.as_str());
109    }
110    let has_multiple_schemas = schemas.len() > 1;
111
112    // Generate struct files for each table
113    for table in &schema_info.tables {
114        let (tokens, imports) =
115            struct_gen::generate_struct(table, db_kind, schema_info, extra_derives, type_overrides, false);
116        let imports = filter_imports(&imports, single_file);
117        let code = format_tokens_with_imports(&tokens, &imports);
118        let module_name = build_module_name(&table.schema_name, &table.name, has_multiple_schemas);
119        let origin = format!("Table: {}.{}", table.schema_name, table.name);
120        files.push(GeneratedFile {
121            filename: format!("{}.rs", module_name),
122            origin: Some(origin),
123            code,
124        });
125    }
126
127    // Generate struct files for each view
128    for view in &schema_info.views {
129        let (tokens, imports) =
130            struct_gen::generate_struct(view, db_kind, schema_info, extra_derives, type_overrides, true);
131        let imports = filter_imports(&imports, single_file);
132        let code = format_tokens_with_imports(&tokens, &imports);
133        let module_name = build_module_name(&view.schema_name, &view.name, has_multiple_schemas);
134        let origin = format!("View: {}.{}", view.schema_name, view.name);
135        files.push(GeneratedFile {
136            filename: format!("{}.rs", module_name),
137            origin: Some(origin),
138            code,
139        });
140    }
141
142    // Generate types file (enums, composites, domains)
143    // Each item is formatted individually so we can insert blank lines between them.
144    let mut types_blocks: Vec<String> = Vec::new();
145    let mut types_imports = BTreeSet::new();
146
147    for enum_info in &schema_info.enums {
148        let (tokens, imports) = enum_gen::generate_enum(enum_info, db_kind, extra_derives);
149        types_blocks.push(format_tokens(&tokens));
150        types_imports.extend(imports);
151    }
152
153    for composite in &schema_info.composite_types {
154        let (tokens, imports) = composite_gen::generate_composite(
155            composite,
156            db_kind,
157            schema_info,
158            extra_derives,
159            type_overrides,
160        );
161        types_blocks.push(format_tokens(&tokens));
162        types_imports.extend(imports);
163    }
164
165    for domain in &schema_info.domains {
166        let (tokens, imports) =
167            domain_gen::generate_domain(domain, db_kind, schema_info, type_overrides);
168        types_blocks.push(format_tokens(&tokens));
169        types_imports.extend(imports);
170    }
171
172    if !types_blocks.is_empty() {
173        let import_lines: String = types_imports
174            .iter()
175            .map(|i| format!("{}\n", i))
176            .collect();
177        let body = types_blocks.join("\n");
178        let code = if import_lines.is_empty() {
179            body
180        } else {
181            format!("{}\n\n{}", import_lines.trim_end(), body)
182        };
183        files.push(GeneratedFile {
184            filename: "types.rs".to_string(),
185            origin: None,
186            code,
187        });
188    }
189
190    files
191}
192
/// Returns the imports that should be kept for a generated table/view file.
///
/// In single-file mode every import whose text contains `super::types::` is dropped,
/// because the types live in the same file; otherwise the set is returned unchanged (cloned).
fn filter_imports(imports: &BTreeSet<String>, single_file: bool) -> BTreeSet<String> {
    imports
        .iter()
        .filter(|import| !(single_file && import.contains("super::types::")))
        .cloned()
        .collect()
}
205
206/// Parse and format a TokenStream via prettyplease, then post-process spacing.
207pub(crate) fn parse_and_format(tokens: &TokenStream) -> String {
208    let file = syn::parse2::<syn::File>(tokens.clone()).unwrap_or_else(|e| {
209        log::error!("Failed to parse generated code: {}", e);
210        log::error!("This is a bug in sqlx-gen. Raw tokens:\n  {}", tokens);
211        std::process::exit(1);
212    });
213    let raw = prettyplease::unparse(&file);
214    add_blank_lines_between_items(&raw)
215}
216
/// Format a single TokenStream block (no imports).
///
/// Thin wrapper around `parse_and_format`; the returned text contains only the
/// formatted items, with no `use` lines prepended.
pub(crate) fn format_tokens(tokens: &TokenStream) -> String {
    parse_and_format(tokens)
}
221
222pub fn format_tokens_with_imports(tokens: &TokenStream, imports: &BTreeSet<String>) -> String {
223    let formatted = parse_and_format(tokens);
224
225    let used_imports: Vec<&String> = imports
226        .iter()
227        .filter(|imp| is_import_used(imp, &formatted))
228        .collect();
229
230    if used_imports.is_empty() {
231        formatted
232    } else {
233        let import_lines: String = used_imports
234            .iter()
235            .map(|i| format!("{}\n", i))
236            .collect();
237        format!("{}\n\n{}", import_lines.trim_end(), formatted)
238    }
239}
240
/// Decides whether a `use` line is needed by `code`.
///
/// The imported names are extracted from the statement and looked up in `code` with a
/// plain substring search (so a name that merely occurs inside a longer identifier also
/// counts as used):
///
/// - `use foo::bar::*;`   → always considered used
/// - `use foo::{A, B};`   → used if `A` or `B` occurs in `code`
/// - `use foo::bar::Baz;` → used if `Baz` occurs in `code`
fn is_import_used(import: &str, code: &str) -> bool {
    let statement = import.trim().trim_end_matches(';');
    let path = statement.strip_prefix("use ").unwrap_or(statement);

    // Glob imports can't be checked by name, so they are always kept.
    if path.ends_with("::*") {
        return true;
    }

    match (path.find('{'), path.find('}')) {
        // Grouped import: any one of the listed names being present is enough.
        (Some(open), Some(close)) => path[open + 1..close]
            .split(',')
            .map(str::trim)
            .any(|name| !name.is_empty() && code.contains(name)),
        // Simple import: look for the last path segment.
        _ => {
            let name = path.rsplit("::").next().unwrap_or(path);
            code.contains(name)
        }
    }
}
273
/// Post-process formatted code to:
/// - Add blank lines between enum variants with `#[sqlx(rename`
/// - Add blank lines between top-level items (structs, impls) and methods
/// - Add blank lines between logical blocks inside async methods
///
/// Each pair of adjacent input lines is checked by `needs_blank_line_between`; at most one
/// blank line is inserted between them, even when several rules apply to the same pair.
/// The result is joined with `\n` and has no trailing newline.
fn add_blank_lines_between_items(code: &str) -> String {
    let lines: Vec<&str> = code.lines().collect();
    let mut result: Vec<&str> = Vec::with_capacity(lines.len());
    let mut previous: Option<&str> = None;

    for line in lines {
        if let Some(prev) = previous {
            if needs_blank_line_between(prev.trim(), line.trim()) {
                result.push("");
            }
        }
        result.push(line);
        previous = Some(line);
    }

    result.join("\n")
}

/// Decides whether a blank line belongs between two adjacent (already trimmed) lines.
fn needs_blank_line_between(prev: &str, current: &str) -> bool {
    // Before `#[sqlx(rename` that follows a variant line (ending with `,`). The first
    // variant is preceded by `{`, so it is left alone.
    let after_variant = prev.ends_with(',') && current.starts_with("#[sqlx(rename");

    // Before top-level items and methods inside impl blocks, when preceded by a closing `}`.
    let after_closing_brace = prev == "}"
        && ["pub struct", "impl ", "#[derive", "pub async fn", "pub fn"]
            .iter()
            .any(|prefix| current.starts_with(*prefix));

    // Before `let` or `Ok(` when the previous statement ended with `.await?;`, `.await?`
    // or was an `.unwrap_or(…);` statement.
    let prev_ends_step = prev.ends_with(".await?;")
        || prev.ends_with(".await?")
        || (prev.ends_with(';') && prev.contains(".unwrap_or("));
    let after_step = prev_ends_step && (current.starts_with("let ") || current.starts_with("Ok("));

    // Separate a sqlx query `let` from a preceding simple (non-sqlx) `let` assignment.
    let before_query = current.starts_with("let ")
        && current.contains("sqlx::")
        && prev.starts_with("let ")
        && !prev.contains("sqlx::");

    after_variant || after_closing_brace || after_step || before_query
}
335
#[cfg(test)]
mod tests {
    // Unit tests for the helpers in this module and for the `generate()` orchestrator.
    use super::*;
    use crate::introspect::{
        ColumnInfo, CompositeTypeInfo, DomainInfo, EnumInfo, SchemaInfo, TableInfo,
    };
    use std::collections::HashMap;

    // ========== is_rust_keyword ==========

    #[test]
    fn test_keyword_type() {
        assert!(is_rust_keyword("type"));
    }

    #[test]
    fn test_keyword_fn() {
        assert!(is_rust_keyword("fn"));
    }

    #[test]
    fn test_keyword_let() {
        assert!(is_rust_keyword("let"));
    }

    #[test]
    fn test_keyword_match() {
        assert!(is_rust_keyword("match"));
    }

    #[test]
    fn test_keyword_async() {
        assert!(is_rust_keyword("async"));
    }

    #[test]
    fn test_keyword_await() {
        assert!(is_rust_keyword("await"));
    }

    #[test]
    fn test_keyword_yield() {
        assert!(is_rust_keyword("yield"));
    }

    #[test]
    fn test_keyword_abstract() {
        assert!(is_rust_keyword("abstract"));
    }

    #[test]
    fn test_keyword_try() {
        assert!(is_rust_keyword("try"));
    }

    #[test]
    fn test_not_keyword_name() {
        assert!(!is_rust_keyword("name"));
    }

    #[test]
    fn test_not_keyword_id() {
        assert!(!is_rust_keyword("id"));
    }

    // The lookup is case-sensitive: only the lowercase spelling is a keyword.
    #[test]
    fn test_not_keyword_uppercase_type() {
        assert!(!is_rust_keyword("Type"));
    }

    // ========== normalize_module_name ==========

    #[test]
    fn test_normalize_no_underscores() {
        assert_eq!(normalize_module_name("users"), "users");
    }

    #[test]
    fn test_normalize_single_underscore() {
        assert_eq!(normalize_module_name("user_roles"), "user_roles");
    }

    #[test]
    fn test_normalize_double_underscore() {
        assert_eq!(normalize_module_name("user__roles"), "user_roles");
    }

    #[test]
    fn test_normalize_triple_underscore() {
        assert_eq!(normalize_module_name("a___b"), "a_b");
    }

    #[test]
    fn test_normalize_leading_underscore() {
        assert_eq!(normalize_module_name("_private"), "_private");
    }

    #[test]
    fn test_normalize_trailing_underscore() {
        assert_eq!(normalize_module_name("name_"), "name_");
    }

    #[test]
    fn test_normalize_double_leading() {
        assert_eq!(normalize_module_name("__double_leading"), "_double_leading");
    }

    #[test]
    fn test_normalize_multiple_groups() {
        assert_eq!(normalize_module_name("a__b__c"), "a_b_c");
    }

    // ========== build_module_name ==========

    #[test]
    fn test_build_single_schema_no_prefix() {
        assert_eq!(build_module_name("public", "users", false), "users");
    }

    #[test]
    fn test_build_multi_schema_default_no_prefix() {
        assert_eq!(build_module_name("public", "users", true), "users");
    }

    #[test]
    fn test_build_multi_schema_non_default_prefixed() {
        assert_eq!(build_module_name("billing", "users", true), "billing_users");
    }

    #[test]
    fn test_build_multi_schema_dbo_no_prefix() {
        assert_eq!(build_module_name("dbo", "users", true), "users");
    }

    #[test]
    fn test_build_multi_schema_main_no_prefix() {
        assert_eq!(build_module_name("main", "users", true), "users");
    }

    #[test]
    fn test_build_normalizes_double_underscore() {
        assert_eq!(build_module_name("billing", "agent__connector", true), "billing_agent_connector");
    }

    // ========== is_default_schema ==========

    #[test]
    fn test_default_schema_public() {
        assert!(is_default_schema("public"));
    }

    #[test]
    fn test_default_schema_main() {
        assert!(is_default_schema("main"));
    }

    #[test]
    fn test_non_default_schema() {
        assert!(!is_default_schema("billing"));
    }

    // ========== imports_for_derives ==========

    #[test]
    fn test_imports_empty() {
        let result = imports_for_derives(&[]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_imports_serialize_only() {
        let derives = vec!["Serialize".to_string()];
        let result = imports_for_derives(&derives);
        assert_eq!(result, vec!["use serde::{Serialize};"]);
    }

    #[test]
    fn test_imports_deserialize_only() {
        let derives = vec!["Deserialize".to_string()];
        let result = imports_for_derives(&derives);
        assert_eq!(result, vec!["use serde::{Deserialize};"]);
    }

    #[test]
    fn test_imports_both_serde() {
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let result = imports_for_derives(&derives);
        assert_eq!(result, vec!["use serde::{Serialize, Deserialize};"]);
    }

    #[test]
    fn test_imports_non_serde() {
        let derives = vec!["Hash".to_string()];
        let result = imports_for_derives(&derives);
        assert!(result.is_empty());
    }

    #[test]
    fn test_imports_non_serde_multiple() {
        let derives = vec!["PartialEq".to_string(), "Eq".to_string()];
        let result = imports_for_derives(&derives);
        assert!(result.is_empty());
    }

    #[test]
    fn test_imports_mixed_serde_and_others() {
        let derives = vec![
            "Serialize".to_string(),
            "Hash".to_string(),
            "Deserialize".to_string(),
        ];
        let result = imports_for_derives(&derives);
        assert_eq!(result.len(), 1);
        assert!(result[0].contains("Serialize"));
        assert!(result[0].contains("Deserialize"));
    }

    // ========== add_blank_lines_between_items ==========

    #[test]
    fn test_blank_lines_between_renamed_variants() {
        let input = "pub enum Foo {\n    #[sqlx(rename = \"a\")]\n    A,\n    #[sqlx(rename = \"b\")]\n    B,\n}";
        let result = add_blank_lines_between_items(input);
        assert!(result.contains("A,\n\n    #[sqlx(rename = \"b\")]"));
    }

    #[test]
    fn test_no_blank_line_for_first_variant() {
        let input = "pub enum Foo {\n    #[sqlx(rename = \"a\")]\n    A,\n}";
        let result = add_blank_lines_between_items(input);
        // No blank line before first #[sqlx(rename because previous line is `{`
        assert!(!result.contains("{\n\n"));
    }

    #[test]
    fn test_no_change_without_rename() {
        let input = "pub enum Foo {\n    A,\n    B,\n}";
        let result = add_blank_lines_between_items(input);
        assert_eq!(result, input);
    }

    #[test]
    fn test_no_change_for_struct() {
        let input = "pub struct Foo {\n    pub a: i32,\n    pub b: String,\n}";
        let result = add_blank_lines_between_items(input);
        assert_eq!(result, input);
    }

    // ========== filter_imports ==========

    #[test]
    fn test_filter_single_file_strips_super_types() {
        let mut imports = BTreeSet::new();
        imports.insert("use super::types::Foo;".to_string());
        imports.insert("use chrono::NaiveDateTime;".to_string());
        let result = filter_imports(&imports, true);
        assert!(!result.contains("use super::types::Foo;"));
        assert!(result.contains("use chrono::NaiveDateTime;"));
    }

    #[test]
    fn test_filter_single_file_keeps_other_imports() {
        let mut imports = BTreeSet::new();
        imports.insert("use chrono::NaiveDateTime;".to_string());
        let result = filter_imports(&imports, true);
        assert!(result.contains("use chrono::NaiveDateTime;"));
    }

    #[test]
    fn test_filter_multi_file_keeps_all() {
        let mut imports = BTreeSet::new();
        imports.insert("use super::types::Foo;".to_string());
        imports.insert("use chrono::NaiveDateTime;".to_string());
        let result = filter_imports(&imports, false);
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_filter_empty_set() {
        let imports = BTreeSet::new();
        let result = filter_imports(&imports, true);
        assert!(result.is_empty());
    }

    // ========== generate() orchestrator ==========

    /// Builds a `TableInfo` named `name` in the `public` schema with the given columns.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// Builds a non-nullable, non-primary-key column in the `public` schema.
    /// `udt_name` is used for both the `data_type` and `udt_name` fields.
    fn make_col(name: &str, udt_name: &str) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: false,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
        }
    }

    #[test]
    fn test_generate_empty_schema() {
        let schema = SchemaInfo::default();
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert!(files.is_empty());
    }

    #[test]
    fn test_generate_one_table() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].filename, "users.rs");
    }

    #[test]
    fn test_generate_two_tables() {
        let schema = SchemaInfo {
            tables: vec![
                make_table("users", vec![make_col("id", "int4")]),
                make_table("posts", vec![make_col("id", "int4")]),
            ],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 2);
    }

    #[test]
    fn test_generate_enum_creates_types_file() {
        let schema = SchemaInfo {
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string(), "inactive".to_string()],
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].filename, "types.rs");
    }

    #[test]
    fn test_generate_enums_composites_domains_single_types_file() {
        let schema = SchemaInfo {
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string()],
            }],
            composite_types: vec![CompositeTypeInfo {
                schema_name: "public".to_string(),
                name: "address".to_string(),
                fields: vec![make_col("street", "text")],
            }],
            domains: vec![DomainInfo {
                schema_name: "public".to_string(),
                name: "email".to_string(),
                base_type: "text".to_string(),
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        // Should produce exactly 1 types.rs
        let types_files: Vec<_> = files.iter().filter(|f| f.filename == "types.rs").collect();
        assert_eq!(types_files.len(), 1);
    }

    #[test]
    fn test_generate_tables_and_enums() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string()],
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 2); // users.rs + types.rs
    }

    #[test]
    fn test_generate_filename_normalized() {
        let schema = SchemaInfo {
            tables: vec![make_table("user__data", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files[0].filename, "user_data.rs");
    }

    #[test]
    fn test_generate_origin_correct() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files[0].origin, Some("Table: public.users".to_string()));
    }

    #[test]
    fn test_generate_types_no_origin() {
        let schema = SchemaInfo {
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string()],
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files[0].origin, None);
    }

    // NOTE(review): the only column here is `id: int4`, which does not reference the enum,
    // so this assertion presumably holds regardless of `single_file`. Consider using a
    // `status` column (as in the multi-file test below) to exercise the filtering — confirm.
    #[test]
    fn test_generate_single_file_filters_super_types_imports() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string()],
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), true);
        // struct file should not have super::types:: imports
        let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
        assert!(!struct_file.code.contains("super::types::"));
    }

    #[test]
    fn test_generate_multi_file_keeps_super_types_imports() {
        // Table with a column referencing an enum
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("status", "status")])],
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string()],
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
        assert!(struct_file.code.contains("super::types::"));
    }

    #[test]
    fn test_generate_extra_derives_in_struct() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let derives = vec!["Serialize".to_string()];
        let files = generate(&schema, DatabaseKind::Postgres, &derives, &HashMap::new(), false);
        assert!(files[0].code.contains("Serialize"));
    }

    #[test]
    fn test_generate_extra_derives_in_enum() {
        let schema = SchemaInfo {
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string()],
            }],
            ..Default::default()
        };
        let derives = vec!["Serialize".to_string()];
        let files = generate(&schema, DatabaseKind::Postgres, &derives, &HashMap::new(), false);
        assert!(files[0].code.contains("Serialize"));
    }

    #[test]
    fn test_generate_type_overrides_in_struct() {
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("data", "jsonb")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &overrides, false);
        assert!(files[0].code.contains("MyJson"));
    }

    #[test]
    fn test_generate_valid_rust_syntax() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![
                make_col("id", "int4"),
                make_col("name", "text"),
            ])],
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string(), "inactive".to_string()],
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        for f in &files {
            // Should be parseable as valid Rust
            let parse_result = syn::parse_file(&f.code);
            assert!(parse_result.is_ok(), "Failed to parse {}: {:?}", f.filename, parse_result.err());
        }
    }

    // ========== generate() — views ==========

    /// Builds a view description in the `public` schema. Views use the same `TableInfo`
    /// type as tables, so this has the same body as `make_table`.
    fn make_view(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    #[test]
    fn test_generate_one_view() {
        let schema = SchemaInfo {
            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].filename, "active_users.rs");
    }

    #[test]
    fn test_generate_view_origin() {
        let schema = SchemaInfo {
            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files[0].origin, Some("View: public.active_users".to_string()));
    }

    #[test]
    fn test_generate_tables_and_views() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 2);
    }

    #[test]
    fn test_generate_view_valid_rust() {
        let schema = SchemaInfo {
            views: vec![make_view("active_users", vec![
                make_col("id", "int4"),
                make_col("name", "text"),
            ])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        let parse_result = syn::parse_file(&files[0].code);
        assert!(parse_result.is_ok(), "Failed to parse: {:?}", parse_result.err());
    }

    #[test]
    fn test_generate_view_nullable_column() {
        let schema = SchemaInfo {
            views: vec![make_view("v", vec![ColumnInfo {
                name: "email".to_string(),
                data_type: "text".to_string(),
                udt_name: "text".to_string(),
                is_nullable: true,
                is_primary_key: false,
                ordinal_position: 0,
                schema_name: "public".to_string(),
            }])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert!(files[0].code.contains("Option<String>"));
    }

    // Two tables named `users` in different schemas: only the non-default schema is prefixed.
    #[test]
    fn test_generate_multi_schema_prefixes_non_default() {
        let schema = SchemaInfo {
            tables: vec![
                make_table("users", vec![make_col("id", "int4")]),
                TableInfo {
                    schema_name: "billing".to_string(),
                    name: "users".to_string(),
                    columns: vec![make_col("id", "int4")],
                },
            ],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        let filenames: Vec<_> = files.iter().map(|f| f.filename.as_str()).collect();
        assert!(filenames.contains(&"users.rs"));
        assert!(filenames.contains(&"billing_users.rs"));
    }

    #[test]
    fn test_generate_single_schema_no_prefix() {
        let schema = SchemaInfo {
            tables: vec![
                make_table("users", vec![make_col("id", "int4")]),
                make_table("posts", vec![make_col("id", "int4")]),
            ],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files[0].filename, "users.rs");
        assert_eq!(files[1].filename, "posts.rs");
    }

    #[test]
    fn test_generate_view_single_file_mode() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), true);
        assert_eq!(files.len(), 2);
    }
}