Skip to main content

sqlx_gen/codegen/
mod.rs

1pub mod composite_gen;
2pub mod crud_gen;
3pub mod domain_gen;
4pub mod entity_parser;
5pub mod enum_gen;
6pub mod struct_gen;
7
8use std::collections::{BTreeSet, HashMap};
9
10use proc_macro2::TokenStream;
11
12use crate::cli::DatabaseKind;
13use crate::introspect::SchemaInfo;
14
/// Rust keywords that cannot be used as plain identifiers.
///
/// Contains the strict keywords, the keywords reserved for future use, and `gen`, which is
/// reserved starting with the Rust 2024 edition. Matching is case-sensitive (`Self` is a
/// keyword, `Type` is not).
const RUST_KEYWORDS: &[&str] = &[
    // Strict keywords.
    "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum",
    "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move",
    "mut", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true",
    "type", "unsafe", "use", "where", "while",
    // Reserved for future use.
    "yield", "abstract", "become", "box", "do", "final", "macro", "override", "priv", "try",
    "typeof", "unsized", "virtual",
    // Reserved starting with the 2024 edition.
    "gen",
];

/// Returns true if the given name is a Rust keyword (case-sensitive).
pub fn is_rust_keyword(name: &str) -> bool {
    RUST_KEYWORDS.contains(&name)
}
28
/// Build the `use` lines required by well-known entries of `extra_derives`.
///
/// Only serde's `Serialize` and `Deserialize` are recognised. When either is requested, a
/// single grouped `use serde::{…};` line is returned. It lists them in the order
/// `Serialize`, `Deserialize`, and keeps the braces even for a single name. Any other
/// derive needs no import, so the result is empty for those.
pub fn imports_for_derives(extra_derives: &[String]) -> Vec<String> {
    let requested = |wanted: &str| extra_derives.iter().any(|derive| derive.as_str() == wanted);
    let serde_traits: Vec<&str> = ["Serialize", "Deserialize"]
        .iter()
        .copied()
        .filter(|&trait_name| requested(trait_name))
        .collect();

    if serde_traits.is_empty() {
        return Vec::new();
    }
    vec![format!("use serde::{{{}}};", serde_traits.join(", "))]
}
45
/// Normalize a table name for use as a Rust module/filename.
///
/// Every run of consecutive underscores is collapsed into a single `_`. All other
/// characters, including leading and trailing underscores, are kept unchanged.
pub fn normalize_module_name(name: &str) -> String {
    let mut normalized = String::with_capacity(name.len());
    for ch in name.chars() {
        // An underscore is dropped when the last character written is already one.
        if ch == '_' && normalized.ends_with('_') {
            continue;
        }
        normalized.push(ch);
    }
    normalized
}
64
/// Well-known default schema names; tables living in them never get a schema prefix in
/// their module/file name.
const DEFAULT_SCHEMAS: &[&str] = &["public", "main", "dbo"];

/// Returns true when `schema` is exactly one of the well-known defaults
/// (`public`, `main`, `dbo`). The comparison is case-sensitive.
pub fn is_default_schema(schema: &str) -> bool {
    DEFAULT_SCHEMAS.iter().any(|&known| known == schema)
}
72
73/// Build a module name, prefixing with schema only when the name collides
74/// (same table name exists in multiple schemas).
75pub fn build_module_name(schema_name: &str, table_name: &str, name_collides: bool) -> String {
76    if name_collides && !is_default_schema(schema_name) {
77        normalize_module_name(&format!("{}_{}", schema_name, table_name))
78    } else {
79        normalize_module_name(table_name)
80    }
81}
82
83/// Find table/view names that appear in more than one schema.
84fn find_colliding_names(schema_info: &SchemaInfo) -> BTreeSet<&str> {
85    let mut seen: HashMap<&str, BTreeSet<&str>> = HashMap::new();
86    for t in &schema_info.tables {
87        seen.entry(t.name.as_str()).or_default().insert(t.schema_name.as_str());
88    }
89    for v in &schema_info.views {
90        seen.entry(v.name.as_str()).or_default().insert(v.schema_name.as_str());
91    }
92    seen.into_iter()
93        .filter(|(_, schemas)| schemas.len() > 1)
94        .map(|(name, _)| name)
95        .collect()
96}
97
/// One generated Rust source file.
///
/// Any `use` lines the file needs are already part of `code`; there is no separate
/// import list.
#[derive(Debug, Clone)]
pub struct GeneratedFile {
    /// File name including the `.rs` extension (e.g. `users.rs`, `types.rs`).
    pub filename: String,
    /// Optional origin comment (e.g. "Table: schema.name"); `None` for the shared types file.
    pub origin: Option<String>,
    /// Formatted source text of the file.
    pub code: String,
}
106
107/// Generate all code for a given schema.
108pub fn generate(
109    schema_info: &SchemaInfo,
110    db_kind: DatabaseKind,
111    extra_derives: &[String],
112    type_overrides: &HashMap<String, String>,
113    single_file: bool,
114) -> Vec<GeneratedFile> {
115    let mut files = Vec::new();
116
117    // Detect table/view names that appear in multiple schemas (collisions)
118    let colliding_names = find_colliding_names(schema_info);
119
120    // Generate struct files for each table
121    for table in &schema_info.tables {
122        let (tokens, imports) =
123            struct_gen::generate_struct(table, db_kind, schema_info, extra_derives, type_overrides, false);
124        let imports = filter_imports(&imports, single_file);
125        let code = format_tokens_with_imports(&tokens, &imports);
126        let module_name = build_module_name(&table.schema_name, &table.name, colliding_names.contains(table.name.as_str()));
127        let origin = format!("Table: {}.{}", table.schema_name, table.name);
128        files.push(GeneratedFile {
129            filename: format!("{}.rs", module_name),
130            origin: Some(origin),
131            code,
132        });
133    }
134
135    // Generate struct files for each view
136    for view in &schema_info.views {
137        let (tokens, imports) =
138            struct_gen::generate_struct(view, db_kind, schema_info, extra_derives, type_overrides, true);
139        let imports = filter_imports(&imports, single_file);
140        let code = format_tokens_with_imports(&tokens, &imports);
141        let module_name = build_module_name(&view.schema_name, &view.name, colliding_names.contains(view.name.as_str()));
142        let origin = format!("View: {}.{}", view.schema_name, view.name);
143        files.push(GeneratedFile {
144            filename: format!("{}.rs", module_name),
145            origin: Some(origin),
146            code,
147        });
148    }
149
150    // Generate types file (enums, composites, domains)
151    // Each item is formatted individually so we can insert blank lines between them.
152    let mut types_blocks: Vec<String> = Vec::new();
153    let mut types_imports = BTreeSet::new();
154
155    for enum_info in &schema_info.enums {
156        let (tokens, imports) = enum_gen::generate_enum(enum_info, db_kind, extra_derives);
157        types_blocks.push(format_tokens(&tokens));
158        types_imports.extend(imports);
159    }
160
161    for composite in &schema_info.composite_types {
162        let (tokens, imports) = composite_gen::generate_composite(
163            composite,
164            db_kind,
165            schema_info,
166            extra_derives,
167            type_overrides,
168        );
169        types_blocks.push(format_tokens(&tokens));
170        types_imports.extend(imports);
171    }
172
173    for domain in &schema_info.domains {
174        let (tokens, imports) =
175            domain_gen::generate_domain(domain, db_kind, schema_info, type_overrides);
176        types_blocks.push(format_tokens(&tokens));
177        types_imports.extend(imports);
178    }
179
180    if !types_blocks.is_empty() {
181        let import_lines: String = types_imports
182            .iter()
183            .map(|i| format!("{}\n", i))
184            .collect();
185        let body = types_blocks.join("\n");
186        let code = if import_lines.is_empty() {
187            body
188        } else {
189            format!("{}\n\n{}", import_lines.trim_end(), body)
190        };
191        files.push(GeneratedFile {
192            filename: "types.rs".to_string(),
193            origin: None,
194            code,
195        });
196    }
197
198    files
199}
200
/// Drop `use super::types::…` imports when generating a single file.
///
/// In single-file mode the types live in the same file, so any import whose text contains
/// `super::types::` is removed. Otherwise the set is returned unchanged as a copy.
fn filter_imports(imports: &BTreeSet<String>, single_file: bool) -> BTreeSet<String> {
    imports
        .iter()
        .filter(|import| !(single_file && import.contains("super::types::")))
        .cloned()
        .collect()
}
213
214/// Parse and format a TokenStream via prettyplease, then post-process spacing.
215pub(crate) fn parse_and_format(tokens: &TokenStream) -> String {
216    let file = syn::parse2::<syn::File>(tokens.clone()).unwrap_or_else(|e| {
217        log::error!("Failed to parse generated code: {}", e);
218        log::error!("This is a bug in sqlx-gen. Raw tokens:\n  {}", tokens);
219        std::process::exit(1);
220    });
221    let raw = prettyplease::unparse(&file);
222    add_blank_lines_between_items(&raw)
223}
224
225/// Format a single TokenStream block (no imports).
226pub(crate) fn format_tokens(tokens: &TokenStream) -> String {
227    parse_and_format(tokens)
228}
229
230pub fn format_tokens_with_imports(tokens: &TokenStream, imports: &BTreeSet<String>) -> String {
231    let formatted = parse_and_format(tokens);
232
233    let used_imports: Vec<&String> = imports
234        .iter()
235        .filter(|imp| is_import_used(imp, &formatted))
236        .collect();
237
238    if used_imports.is_empty() {
239        formatted
240    } else {
241        let import_lines: String = used_imports
242            .iter()
243            .map(|i| format!("{}\n", i))
244            .collect();
245        format!("{}\n\n{}", import_lines.trim_end(), formatted)
246    }
247}
248
/// Check whether an import is actually used in the generated code.
///
/// Each import's bound name is looked up in `code` as a whole identifier. This way
/// `use chrono::NaiveDate;` is not kept merely because `NaiveDateTime` appears.
///
/// - `use foo::bar::Baz;` → checks `Baz`
/// - `use foo::Bar as Baz;` → checks `Baz` (the alias is the name the code uses)
/// - `use foo::{A, b::C};` → used if `A` or `C` is used
/// - `use foo::bar::*;` and `use foo::Trait as _;` → always kept, since usage cannot be
///   detected by name
///
/// This is a textual heuristic. A name that only appears in a string literal or comment
/// counts as used. A trait that is only used through its methods (its name never appears)
/// is reported as unused.
fn is_import_used(import: &str, code: &str) -> bool {
    let trimmed = import.trim().trim_end_matches(';');
    let path = trimmed.strip_prefix("use ").unwrap_or(trimmed);

    if path.ends_with("::*") {
        return true;
    }

    // Grouped import `foo::{A, B, bar::C}`: used if any member is used. Using the last `}`
    // keeps nested groups inside the slice; the `start < end` guard avoids a slice panic
    // on malformed input.
    if let (Some(start), Some(end)) = (path.find('{'), path.rfind('}')) {
        if start < end {
            return path[start + 1..end]
                .split(',')
                .map(str::trim)
                .filter(|member| !member.is_empty())
                .any(|member| is_import_item_used(member, code));
        }
    }

    is_import_item_used(path, code)
}

/// Decide whether one `use` item (a path, optionally with `as alias`) is referenced in `code`.
fn is_import_item_used(item: &str, code: &str) -> bool {
    let raw_name = match item.rsplit_once(" as ") {
        Some((_, alias)) => alias,
        None => item.rsplit("::").next().unwrap_or(item),
    };
    // Members of nested groups (e.g. `bar::{C` or `D}`) carry stray braces.
    let name = raw_name.trim_matches(|c: char| c == '{' || c == '}' || c.is_whitespace());

    match name {
        // Empty, glob, or anonymous (`as _`) import: usage can't be detected by name, keep it.
        "" | "*" | "_" => true,
        _ => contains_identifier(code, name),
    }
}

/// Returns true if `ident` occurs in `haystack` as a whole identifier. Such an occurrence
/// is not immediately preceded or followed by an alphanumeric character or `_`.
fn contains_identifier(haystack: &str, ident: &str) -> bool {
    let is_ident_char = |c: char| c.is_alphanumeric() || c == '_';
    haystack.match_indices(ident).any(|(start, _)| {
        let end = start + ident.len();
        let clear_before = haystack[..start]
            .chars()
            .next_back()
            .map_or(true, |c| !is_ident_char(c));
        let clear_after = haystack[end..]
            .chars()
            .next()
            .map_or(true, |c| !is_ident_char(c));
        clear_before && clear_after
    })
}
281
/// Post-process formatted code by inserting blank lines. A blank line goes:
/// - between enum variants annotated with `#[sqlx(rename …)]`, but not before the first
///   one, which follows the opening `{` rather than a `,`
/// - before `pub struct`, `impl `, `#[derive`, `pub async fn` or `pub fn` lines that
///   directly follow a closing `}` line
/// - between logical blocks inside methods (see `needs_blank_line_between`)
///
/// At most one blank line is inserted before any line, even when several rules match.
/// Lines are re-joined with `\n`, so the result has no trailing newline.
fn add_blank_lines_between_items(code: &str) -> String {
    let mut result: Vec<&str> = Vec::new();
    let mut previous: Option<&str> = None;

    for line in code.lines() {
        if let Some(prev_line) = previous {
            if needs_blank_line_between(prev_line.trim(), line.trim()) {
                result.push("");
            }
        }
        result.push(line);
        previous = Some(line);
    }

    result.join("\n")
}

/// Decide whether a blank line belongs between two consecutive (already trimmed) lines.
fn needs_blank_line_between(prev: &str, line: &str) -> bool {
    // Next `#[sqlx(rename …)]` variant after a variant line ending with `,`.
    let renamed_variant = line.starts_with("#[sqlx(rename") && prev.ends_with(',');

    // Top-level item or method right after a closing brace.
    let item_after_block = prev == "}"
        && (line.starts_with("pub struct")
            || line.starts_with("impl ")
            || line.starts_with("#[derive")
            || line.starts_with("pub async fn")
            || line.starts_with("pub fn"));

    // `let` / `Ok(` after a statement ending in `.await?;` / `.await?` or an `.unwrap_or(…);`.
    let prev_is_await_end = prev.ends_with(".await?;")
        || prev.ends_with(".await?")
        || (prev.ends_with(';') && prev.contains(".unwrap_or("));
    let after_await = prev_is_await_end && (line.starts_with("let ") || line.starts_with("Ok("));

    // A sqlx query `let` following a plain (non-sqlx) `let` assignment.
    let query_after_plain_let = line.starts_with("let ")
        && line.contains("sqlx::")
        && prev.starts_with("let ")
        && !prev.contains("sqlx::");

    renamed_variant || item_after_block || after_await || query_after_plain_let
}
343
344#[cfg(test)]
345mod tests {
346    use super::*;
347    use crate::introspect::{
348        ColumnInfo, CompositeTypeInfo, DomainInfo, EnumInfo, SchemaInfo, TableInfo,
349    };
350    use std::collections::HashMap;
351
352    // ========== is_rust_keyword ==========
353
354    #[test]
355    fn test_keyword_type() {
356        assert!(is_rust_keyword("type"));
357    }
358
359    #[test]
360    fn test_keyword_fn() {
361        assert!(is_rust_keyword("fn"));
362    }
363
364    #[test]
365    fn test_keyword_let() {
366        assert!(is_rust_keyword("let"));
367    }
368
369    #[test]
370    fn test_keyword_match() {
371        assert!(is_rust_keyword("match"));
372    }
373
374    #[test]
375    fn test_keyword_async() {
376        assert!(is_rust_keyword("async"));
377    }
378
379    #[test]
380    fn test_keyword_await() {
381        assert!(is_rust_keyword("await"));
382    }
383
384    #[test]
385    fn test_keyword_yield() {
386        assert!(is_rust_keyword("yield"));
387    }
388
389    #[test]
390    fn test_keyword_abstract() {
391        assert!(is_rust_keyword("abstract"));
392    }
393
394    #[test]
395    fn test_keyword_try() {
396        assert!(is_rust_keyword("try"));
397    }
398
399    #[test]
400    fn test_not_keyword_name() {
401        assert!(!is_rust_keyword("name"));
402    }
403
404    #[test]
405    fn test_not_keyword_id() {
406        assert!(!is_rust_keyword("id"));
407    }
408
409    #[test]
410    fn test_not_keyword_uppercase_type() {
411        assert!(!is_rust_keyword("Type"));
412    }
413
414    // ========== normalize_module_name ==========
415
416    #[test]
417    fn test_normalize_no_underscores() {
418        assert_eq!(normalize_module_name("users"), "users");
419    }
420
421    #[test]
422    fn test_normalize_single_underscore() {
423        assert_eq!(normalize_module_name("user_roles"), "user_roles");
424    }
425
426    #[test]
427    fn test_normalize_double_underscore() {
428        assert_eq!(normalize_module_name("user__roles"), "user_roles");
429    }
430
431    #[test]
432    fn test_normalize_triple_underscore() {
433        assert_eq!(normalize_module_name("a___b"), "a_b");
434    }
435
436    #[test]
437    fn test_normalize_leading_underscore() {
438        assert_eq!(normalize_module_name("_private"), "_private");
439    }
440
441    #[test]
442    fn test_normalize_trailing_underscore() {
443        assert_eq!(normalize_module_name("name_"), "name_");
444    }
445
446    #[test]
447    fn test_normalize_double_leading() {
448        assert_eq!(normalize_module_name("__double_leading"), "_double_leading");
449    }
450
451    #[test]
452    fn test_normalize_multiple_groups() {
453        assert_eq!(normalize_module_name("a__b__c"), "a_b_c");
454    }
455
456    // ========== build_module_name ==========
457
458    #[test]
459    fn test_build_no_collision_no_prefix() {
460        assert_eq!(build_module_name("public", "users", false), "users");
461    }
462
463    #[test]
464    fn test_build_no_collision_non_default_no_prefix() {
465        assert_eq!(build_module_name("billing", "invoices", false), "invoices");
466    }
467
468    #[test]
469    fn test_build_collision_prefixed() {
470        assert_eq!(build_module_name("billing", "users", true), "billing_users");
471    }
472
473    #[test]
474    fn test_build_collision_default_schema_no_prefix() {
475        assert_eq!(build_module_name("public", "users", true), "users");
476    }
477
478    #[test]
479    fn test_build_collision_normalizes_double_underscore() {
480        assert_eq!(build_module_name("billing", "agent__connector", true), "billing_agent_connector");
481    }
482
483    // ========== is_default_schema ==========
484
485    #[test]
486    fn test_default_schema_public() {
487        assert!(is_default_schema("public"));
488    }
489
490    #[test]
491    fn test_default_schema_main() {
492        assert!(is_default_schema("main"));
493    }
494
495    #[test]
496    fn test_non_default_schema() {
497        assert!(!is_default_schema("billing"));
498    }
499
500    // ========== imports_for_derives ==========
501
502    #[test]
503    fn test_imports_empty() {
504        let result = imports_for_derives(&[]);
505        assert!(result.is_empty());
506    }
507
508    #[test]
509    fn test_imports_serialize_only() {
510        let derives = vec!["Serialize".to_string()];
511        let result = imports_for_derives(&derives);
512        assert_eq!(result, vec!["use serde::{Serialize};"]);
513    }
514
515    #[test]
516    fn test_imports_deserialize_only() {
517        let derives = vec!["Deserialize".to_string()];
518        let result = imports_for_derives(&derives);
519        assert_eq!(result, vec!["use serde::{Deserialize};"]);
520    }
521
522    #[test]
523    fn test_imports_both_serde() {
524        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
525        let result = imports_for_derives(&derives);
526        assert_eq!(result, vec!["use serde::{Serialize, Deserialize};"]);
527    }
528
529    #[test]
530    fn test_imports_non_serde() {
531        let derives = vec!["Hash".to_string()];
532        let result = imports_for_derives(&derives);
533        assert!(result.is_empty());
534    }
535
536    #[test]
537    fn test_imports_non_serde_multiple() {
538        let derives = vec!["PartialEq".to_string(), "Eq".to_string()];
539        let result = imports_for_derives(&derives);
540        assert!(result.is_empty());
541    }
542
543    #[test]
544    fn test_imports_mixed_serde_and_others() {
545        let derives = vec![
546            "Serialize".to_string(),
547            "Hash".to_string(),
548            "Deserialize".to_string(),
549        ];
550        let result = imports_for_derives(&derives);
551        assert_eq!(result.len(), 1);
552        assert!(result[0].contains("Serialize"));
553        assert!(result[0].contains("Deserialize"));
554    }
555
556    // ========== add_blank_lines_between_items ==========
557
558    #[test]
559    fn test_blank_lines_between_renamed_variants() {
560        let input = "pub enum Foo {\n    #[sqlx(rename = \"a\")]\n    A,\n    #[sqlx(rename = \"b\")]\n    B,\n}";
561        let result = add_blank_lines_between_items(input);
562        assert!(result.contains("A,\n\n    #[sqlx(rename = \"b\")]"));
563    }
564
565    #[test]
566    fn test_no_blank_line_for_first_variant() {
567        let input = "pub enum Foo {\n    #[sqlx(rename = \"a\")]\n    A,\n}";
568        let result = add_blank_lines_between_items(input);
569        // No blank line before first #[sqlx(rename because previous line is `{`
570        assert!(!result.contains("{\n\n"));
571    }
572
573    #[test]
574    fn test_no_change_without_rename() {
575        let input = "pub enum Foo {\n    A,\n    B,\n}";
576        let result = add_blank_lines_between_items(input);
577        assert_eq!(result, input);
578    }
579
580    #[test]
581    fn test_no_change_for_struct() {
582        let input = "pub struct Foo {\n    pub a: i32,\n    pub b: String,\n}";
583        let result = add_blank_lines_between_items(input);
584        assert_eq!(result, input);
585    }
586
587    // ========== filter_imports ==========
588
589    #[test]
590    fn test_filter_single_file_strips_super_types() {
591        let mut imports = BTreeSet::new();
592        imports.insert("use super::types::Foo;".to_string());
593        imports.insert("use chrono::NaiveDateTime;".to_string());
594        let result = filter_imports(&imports, true);
595        assert!(!result.contains("use super::types::Foo;"));
596        assert!(result.contains("use chrono::NaiveDateTime;"));
597    }
598
599    #[test]
600    fn test_filter_single_file_keeps_other_imports() {
601        let mut imports = BTreeSet::new();
602        imports.insert("use chrono::NaiveDateTime;".to_string());
603        let result = filter_imports(&imports, true);
604        assert!(result.contains("use chrono::NaiveDateTime;"));
605    }
606
607    #[test]
608    fn test_filter_multi_file_keeps_all() {
609        let mut imports = BTreeSet::new();
610        imports.insert("use super::types::Foo;".to_string());
611        imports.insert("use chrono::NaiveDateTime;".to_string());
612        let result = filter_imports(&imports, false);
613        assert_eq!(result.len(), 2);
614    }
615
616    #[test]
617    fn test_filter_empty_set() {
618        let imports = BTreeSet::new();
619        let result = filter_imports(&imports, true);
620        assert!(result.is_empty());
621    }
622
623    // ========== generate() orchestrator ==========
624
625    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
626        TableInfo {
627            schema_name: "public".to_string(),
628            name: name.to_string(),
629            columns,
630        }
631    }
632
633    fn make_col(name: &str, udt_name: &str) -> ColumnInfo {
634        ColumnInfo {
635            name: name.to_string(),
636            data_type: udt_name.to_string(),
637            udt_name: udt_name.to_string(),
638            is_nullable: false,
639            is_primary_key: false,
640            ordinal_position: 0,
641            schema_name: "public".to_string(),
642        }
643    }
644
645    #[test]
646    fn test_generate_empty_schema() {
647        let schema = SchemaInfo::default();
648        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
649        assert!(files.is_empty());
650    }
651
652    #[test]
653    fn test_generate_one_table() {
654        let schema = SchemaInfo {
655            tables: vec![make_table("users", vec![make_col("id", "int4")])],
656            ..Default::default()
657        };
658        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
659        assert_eq!(files.len(), 1);
660        assert_eq!(files[0].filename, "users.rs");
661    }
662
663    #[test]
664    fn test_generate_two_tables() {
665        let schema = SchemaInfo {
666            tables: vec![
667                make_table("users", vec![make_col("id", "int4")]),
668                make_table("posts", vec![make_col("id", "int4")]),
669            ],
670            ..Default::default()
671        };
672        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
673        assert_eq!(files.len(), 2);
674    }
675
676    #[test]
677    fn test_generate_enum_creates_types_file() {
678        let schema = SchemaInfo {
679            enums: vec![EnumInfo {
680                schema_name: "public".to_string(),
681                name: "status".to_string(),
682                variants: vec!["active".to_string(), "inactive".to_string()],
683            }],
684            ..Default::default()
685        };
686        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
687        assert_eq!(files.len(), 1);
688        assert_eq!(files[0].filename, "types.rs");
689    }
690
691    #[test]
692    fn test_generate_enums_composites_domains_single_types_file() {
693        let schema = SchemaInfo {
694            enums: vec![EnumInfo {
695                schema_name: "public".to_string(),
696                name: "status".to_string(),
697                variants: vec!["active".to_string()],
698            }],
699            composite_types: vec![CompositeTypeInfo {
700                schema_name: "public".to_string(),
701                name: "address".to_string(),
702                fields: vec![make_col("street", "text")],
703            }],
704            domains: vec![DomainInfo {
705                schema_name: "public".to_string(),
706                name: "email".to_string(),
707                base_type: "text".to_string(),
708            }],
709            ..Default::default()
710        };
711        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
712        // Should produce exactly 1 types.rs
713        let types_files: Vec<_> = files.iter().filter(|f| f.filename == "types.rs").collect();
714        assert_eq!(types_files.len(), 1);
715    }
716
717    #[test]
718    fn test_generate_tables_and_enums() {
719        let schema = SchemaInfo {
720            tables: vec![make_table("users", vec![make_col("id", "int4")])],
721            enums: vec![EnumInfo {
722                schema_name: "public".to_string(),
723                name: "status".to_string(),
724                variants: vec!["active".to_string()],
725            }],
726            ..Default::default()
727        };
728        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
729        assert_eq!(files.len(), 2); // users.rs + types.rs
730    }
731
732    #[test]
733    fn test_generate_filename_normalized() {
734        let schema = SchemaInfo {
735            tables: vec![make_table("user__data", vec![make_col("id", "int4")])],
736            ..Default::default()
737        };
738        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
739        assert_eq!(files[0].filename, "user_data.rs");
740    }
741
742    #[test]
743    fn test_generate_origin_correct() {
744        let schema = SchemaInfo {
745            tables: vec![make_table("users", vec![make_col("id", "int4")])],
746            ..Default::default()
747        };
748        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
749        assert_eq!(files[0].origin, Some("Table: public.users".to_string()));
750    }
751
752    #[test]
753    fn test_generate_types_no_origin() {
754        let schema = SchemaInfo {
755            enums: vec![EnumInfo {
756                schema_name: "public".to_string(),
757                name: "status".to_string(),
758                variants: vec!["active".to_string()],
759            }],
760            ..Default::default()
761        };
762        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
763        assert_eq!(files[0].origin, None);
764    }
765
766    #[test]
767    fn test_generate_single_file_filters_super_types_imports() {
768        let schema = SchemaInfo {
769            tables: vec![make_table("users", vec![make_col("id", "int4")])],
770            enums: vec![EnumInfo {
771                schema_name: "public".to_string(),
772                name: "status".to_string(),
773                variants: vec!["active".to_string()],
774            }],
775            ..Default::default()
776        };
777        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), true);
778        // struct file should not have super::types:: imports
779        let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
780        assert!(!struct_file.code.contains("super::types::"));
781    }
782
783    #[test]
784    fn test_generate_multi_file_keeps_super_types_imports() {
785        // Table with a column referencing an enum
786        let schema = SchemaInfo {
787            tables: vec![make_table("users", vec![make_col("status", "status")])],
788            enums: vec![EnumInfo {
789                schema_name: "public".to_string(),
790                name: "status".to_string(),
791                variants: vec!["active".to_string()],
792            }],
793            ..Default::default()
794        };
795        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
796        let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
797        assert!(struct_file.code.contains("super::types::"));
798    }
799
800    #[test]
801    fn test_generate_extra_derives_in_struct() {
802        let schema = SchemaInfo {
803            tables: vec![make_table("users", vec![make_col("id", "int4")])],
804            ..Default::default()
805        };
806        let derives = vec!["Serialize".to_string()];
807        let files = generate(&schema, DatabaseKind::Postgres, &derives, &HashMap::new(), false);
808        assert!(files[0].code.contains("Serialize"));
809    }
810
811    #[test]
812    fn test_generate_extra_derives_in_enum() {
813        let schema = SchemaInfo {
814            enums: vec![EnumInfo {
815                schema_name: "public".to_string(),
816                name: "status".to_string(),
817                variants: vec!["active".to_string()],
818            }],
819            ..Default::default()
820        };
821        let derives = vec!["Serialize".to_string()];
822        let files = generate(&schema, DatabaseKind::Postgres, &derives, &HashMap::new(), false);
823        assert!(files[0].code.contains("Serialize"));
824    }
825
826    #[test]
827    fn test_generate_type_overrides_in_struct() {
828        let mut overrides = HashMap::new();
829        overrides.insert("jsonb".to_string(), "MyJson".to_string());
830        let schema = SchemaInfo {
831            tables: vec![make_table("users", vec![make_col("data", "jsonb")])],
832            ..Default::default()
833        };
834        let files = generate(&schema, DatabaseKind::Postgres, &[], &overrides, false);
835        assert!(files[0].code.contains("MyJson"));
836    }
837
838    #[test]
839    fn test_generate_valid_rust_syntax() {
840        let schema = SchemaInfo {
841            tables: vec![make_table("users", vec![
842                make_col("id", "int4"),
843                make_col("name", "text"),
844            ])],
845            enums: vec![EnumInfo {
846                schema_name: "public".to_string(),
847                name: "status".to_string(),
848                variants: vec!["active".to_string(), "inactive".to_string()],
849            }],
850            ..Default::default()
851        };
852        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
853        for f in &files {
854            // Should be parseable as valid Rust
855            let parse_result = syn::parse_file(&f.code);
856            assert!(parse_result.is_ok(), "Failed to parse {}: {:?}", f.filename, parse_result.err());
857        }
858    }
859
860    // ========== generate() — views ==========
861
862    fn make_view(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
863        TableInfo {
864            schema_name: "public".to_string(),
865            name: name.to_string(),
866            columns,
867        }
868    }
869
870    #[test]
871    fn test_generate_one_view() {
872        let schema = SchemaInfo {
873            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
874            ..Default::default()
875        };
876        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
877        assert_eq!(files.len(), 1);
878        assert_eq!(files[0].filename, "active_users.rs");
879    }
880
881    #[test]
882    fn test_generate_view_origin() {
883        let schema = SchemaInfo {
884            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
885            ..Default::default()
886        };
887        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
888        assert_eq!(files[0].origin, Some("View: public.active_users".to_string()));
889    }
890
891    #[test]
892    fn test_generate_tables_and_views() {
893        let schema = SchemaInfo {
894            tables: vec![make_table("users", vec![make_col("id", "int4")])],
895            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
896            ..Default::default()
897        };
898        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
899        assert_eq!(files.len(), 2);
900    }
901
902    #[test]
903    fn test_generate_view_valid_rust() {
904        let schema = SchemaInfo {
905            views: vec![make_view("active_users", vec![
906                make_col("id", "int4"),
907                make_col("name", "text"),
908            ])],
909            ..Default::default()
910        };
911        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
912        let parse_result = syn::parse_file(&files[0].code);
913        assert!(parse_result.is_ok(), "Failed to parse: {:?}", parse_result.err());
914    }
915
916    #[test]
917    fn test_generate_view_nullable_column() {
918        let schema = SchemaInfo {
919            views: vec![make_view("v", vec![ColumnInfo {
920                name: "email".to_string(),
921                data_type: "text".to_string(),
922                udt_name: "text".to_string(),
923                is_nullable: true,
924                is_primary_key: false,
925                ordinal_position: 0,
926                schema_name: "public".to_string(),
927            }])],
928            ..Default::default()
929        };
930        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
931        assert!(files[0].code.contains("Option<String>"));
932    }
933
934    #[test]
935    fn test_generate_collision_both_prefixed() {
936        let schema = SchemaInfo {
937            tables: vec![
938                make_table("users", vec![make_col("id", "int4")]),
939                TableInfo {
940                    schema_name: "billing".to_string(),
941                    name: "users".to_string(),
942                    columns: vec![make_col("id", "int4")],
943                },
944            ],
945            ..Default::default()
946        };
947        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
948        let filenames: Vec<_> = files.iter().map(|f| f.filename.as_str()).collect();
949        assert!(filenames.contains(&"users.rs"));
950        assert!(filenames.contains(&"billing_users.rs"));
951    }
952
953    #[test]
954    fn test_generate_no_collision_no_prefix() {
955        let schema = SchemaInfo {
956            tables: vec![
957                make_table("users", vec![make_col("id", "int4")]),
958                TableInfo {
959                    schema_name: "billing".to_string(),
960                    name: "invoices".to_string(),
961                    columns: vec![make_col("id", "int4")],
962                },
963            ],
964            ..Default::default()
965        };
966        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
967        let filenames: Vec<_> = files.iter().map(|f| f.filename.as_str()).collect();
968        assert!(filenames.contains(&"users.rs"));
969        assert!(filenames.contains(&"invoices.rs"));
970    }
971
972    #[test]
973    fn test_generate_single_schema_no_prefix() {
974        let schema = SchemaInfo {
975            tables: vec![
976                make_table("users", vec![make_col("id", "int4")]),
977                make_table("posts", vec![make_col("id", "int4")]),
978            ],
979            ..Default::default()
980        };
981        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
982        assert_eq!(files[0].filename, "users.rs");
983        assert_eq!(files[1].filename, "posts.rs");
984    }
985
986    #[test]
987    fn test_generate_view_single_file_mode() {
988        let schema = SchemaInfo {
989            tables: vec![make_table("users", vec![make_col("id", "int4")])],
990            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
991            ..Default::default()
992        };
993        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), true);
994        assert_eq!(files.len(), 2);
995    }
996}