1pub mod composite_gen;
2pub mod crud_gen;
3pub mod domain_gen;
4pub mod entity_parser;
5pub mod enum_gen;
6pub mod struct_gen;
7
8use std::collections::{BTreeSet, HashMap};
9
10use proc_macro2::TokenStream;
11
12use crate::cli::DatabaseKind;
13use crate::introspect::SchemaInfo;
14
/// Words that cannot be emitted verbatim as Rust identifiers.
///
/// Contains the strict keywords, the keywords reserved for future use, and
/// `gen`, which is reserved starting with the Rust 2024 edition. Generated code
/// may be compiled by crates on any edition, so edition-specific reservations
/// are included too.
const RUST_KEYWORDS: &[&str] = &[
    // Strict keywords.
    "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum",
    "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move",
    "mut", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true",
    "type", "unsafe", "use", "where", "while",
    // Reserved keywords.
    "yield", "abstract", "become", "box", "do", "final", "macro", "override", "priv", "try",
    "typeof", "unsized", "virtual",
    // Reserved since the 2024 edition.
    "gen",
];

/// Returns `true` if `name` is in [`RUST_KEYWORDS`].
///
/// The comparison is exact and case-sensitive (`"type"` matches, `"Type"` does not).
pub fn is_rust_keyword(name: &str) -> bool {
    RUST_KEYWORDS.contains(&name)
}
28
/// Returns the `use` lines needed by the user-supplied extra derives.
///
/// Only serde's `Serialize` / `Deserialize` need an import. The ones present are
/// combined into a single `use serde::{...};` line, always listed in
/// `Serialize, Deserialize` order and with braces even for a single name. Every
/// other derive contributes nothing, so the result is either empty or one line.
pub fn imports_for_derives(extra_derives: &[String]) -> Vec<String> {
    let requested = |derive: &str| extra_derives.iter().any(|d| d == derive);
    let serde_traits: Vec<&str> = ["Serialize", "Deserialize"]
        .iter()
        .copied()
        .filter(|name| requested(*name))
        .collect();

    if serde_traits.is_empty() {
        Vec::new()
    } else {
        vec![format!("use serde::{{{}}};", serde_traits.join(", "))]
    }
}
45
/// Collapses every run of consecutive underscores in `name` into one underscore.
///
/// All other characters are kept unchanged, including single leading and
/// trailing underscores (`"__a__b_"` becomes `"_a_b_"`).
pub fn normalize_module_name(name: &str) -> String {
    let mut collapsed = String::with_capacity(name.len());
    for ch in name.chars() {
        // The output's last char always equals the last input char seen, so
        // checking the output is the same as checking the previous input char.
        let repeated_underscore = ch == '_' && collapsed.ends_with('_');
        if !repeated_underscore {
            collapsed.push(ch);
        }
    }
    collapsed
}
64
/// Schema names treated as a database's default namespace.
///
/// These are conventionally the defaults of PostgreSQL (`public`), SQLite
/// (`main`) and SQL Server (`dbo`).
const DEFAULT_SCHEMAS: &[&str] = &["public", "main", "dbo"];

/// Returns `true` when `schema` exactly matches (case-sensitively) one of
/// [`DEFAULT_SCHEMAS`].
pub fn is_default_schema(schema: &str) -> bool {
    DEFAULT_SCHEMAS.iter().any(|&default| default == schema)
}
72
73pub fn build_module_name(schema_name: &str, table_name: &str, name_collides: bool) -> String {
76 if name_collides && !is_default_schema(schema_name) {
77 normalize_module_name(&format!("{}_{}", schema_name, table_name))
78 } else {
79 normalize_module_name(table_name)
80 }
81}
82
83fn find_colliding_names(schema_info: &SchemaInfo) -> BTreeSet<&str> {
85 let mut seen: HashMap<&str, BTreeSet<&str>> = HashMap::new();
86 for t in &schema_info.tables {
87 seen.entry(t.name.as_str()).or_default().insert(t.schema_name.as_str());
88 }
89 for v in &schema_info.views {
90 seen.entry(v.name.as_str()).or_default().insert(v.schema_name.as_str());
91 }
92 seen.into_iter()
93 .filter(|(_, schemas)| schemas.len() > 1)
94 .map(|(name, _)| name)
95 .collect()
96}
97
/// One generated Rust source file.
#[derive(Clone, Debug)]
pub struct GeneratedFile {
    /// File name including the `.rs` extension (e.g. `users.rs`, `types.rs`).
    pub filename: String,
    /// Description of the database object the file came from. `generate` sets
    /// it to `"Table: schema.name"` or `"View: schema.name"`, and to `None` for
    /// `types.rs`.
    pub origin: Option<String>,
    /// Formatted Rust source code of the file.
    pub code: String,
}
106
107pub fn generate(
109 schema_info: &SchemaInfo,
110 db_kind: DatabaseKind,
111 extra_derives: &[String],
112 type_overrides: &HashMap<String, String>,
113 single_file: bool,
114) -> Vec<GeneratedFile> {
115 let mut files = Vec::new();
116
117 let colliding_names = find_colliding_names(schema_info);
119
120 for table in &schema_info.tables {
122 let (tokens, imports) =
123 struct_gen::generate_struct(table, db_kind, schema_info, extra_derives, type_overrides, false);
124 let imports = filter_imports(&imports, single_file);
125 let code = format_tokens_with_imports(&tokens, &imports);
126 let module_name = build_module_name(&table.schema_name, &table.name, colliding_names.contains(table.name.as_str()));
127 let origin = format!("Table: {}.{}", table.schema_name, table.name);
128 files.push(GeneratedFile {
129 filename: format!("{}.rs", module_name),
130 origin: Some(origin),
131 code,
132 });
133 }
134
135 for view in &schema_info.views {
137 let (tokens, imports) =
138 struct_gen::generate_struct(view, db_kind, schema_info, extra_derives, type_overrides, true);
139 let imports = filter_imports(&imports, single_file);
140 let code = format_tokens_with_imports(&tokens, &imports);
141 let module_name = build_module_name(&view.schema_name, &view.name, colliding_names.contains(view.name.as_str()));
142 let origin = format!("View: {}.{}", view.schema_name, view.name);
143 files.push(GeneratedFile {
144 filename: format!("{}.rs", module_name),
145 origin: Some(origin),
146 code,
147 });
148 }
149
150 let mut types_blocks: Vec<String> = Vec::new();
153 let mut types_imports = BTreeSet::new();
154
155 for enum_info in &schema_info.enums {
156 let (tokens, imports) = enum_gen::generate_enum(enum_info, db_kind, extra_derives);
157 types_blocks.push(format_tokens(&tokens));
158 types_imports.extend(imports);
159 }
160
161 for composite in &schema_info.composite_types {
162 let (tokens, imports) = composite_gen::generate_composite(
163 composite,
164 db_kind,
165 schema_info,
166 extra_derives,
167 type_overrides,
168 );
169 types_blocks.push(format_tokens(&tokens));
170 types_imports.extend(imports);
171 }
172
173 for domain in &schema_info.domains {
174 let (tokens, imports) =
175 domain_gen::generate_domain(domain, db_kind, schema_info, type_overrides);
176 types_blocks.push(format_tokens(&tokens));
177 types_imports.extend(imports);
178 }
179
180 if !types_blocks.is_empty() {
181 let import_lines: String = types_imports
182 .iter()
183 .map(|i| format!("{}\n", i))
184 .collect();
185 let body = types_blocks.join("\n");
186 let code = if import_lines.is_empty() {
187 body
188 } else {
189 format!("{}\n\n{}", import_lines.trim_end(), body)
190 };
191 files.push(GeneratedFile {
192 filename: "types.rs".to_string(),
193 origin: None,
194 code,
195 });
196 }
197
198 files
199}
200
/// Returns the imports to keep for a table/view file.
///
/// When `single_file` is set, every import containing `super::types::` is
/// dropped (presumably because the types are emitted into the same file in that
/// mode). Otherwise the set is returned unchanged.
fn filter_imports(imports: &BTreeSet<String>, single_file: bool) -> BTreeSet<String> {
    imports
        .iter()
        .filter(|import| !(single_file && import.contains("super::types::")))
        .cloned()
        .collect()
}
213
214pub(crate) fn parse_and_format(tokens: &TokenStream) -> String {
216 let file = syn::parse2::<syn::File>(tokens.clone()).unwrap_or_else(|e| {
217 log::error!("Failed to parse generated code: {}", e);
218 log::error!("This is a bug in sqlx-gen. Raw tokens:\n {}", tokens);
219 std::process::exit(1);
220 });
221 let raw = prettyplease::unparse(&file);
222 add_blank_lines_between_items(&raw)
223}
224
225pub(crate) fn format_tokens(tokens: &TokenStream) -> String {
227 parse_and_format(tokens)
228}
229
230pub fn format_tokens_with_imports(tokens: &TokenStream, imports: &BTreeSet<String>) -> String {
231 let formatted = parse_and_format(tokens);
232
233 let used_imports: Vec<&String> = imports
234 .iter()
235 .filter(|imp| is_import_used(imp, &formatted))
236 .collect();
237
238 if used_imports.is_empty() {
239 formatted
240 } else {
241 let import_lines: String = used_imports
242 .iter()
243 .map(|i| format!("{}\n", i))
244 .collect();
245 format!("{}\n\n{}", import_lines.trim_end(), formatted)
246 }
247}
248
/// Heuristically decides whether the `use` line `import` is needed by `code`.
///
/// The check is a plain substring search for the name(s) the import binds. It
/// may keep an import whose name only appears inside a longer identifier (e.g.
/// `NaiveDate` within `NaiveDateTime`). It never drops one whose bound name
/// appears in `code`.
///
/// * Glob imports (`a::*`) are always considered used.
/// * Grouped imports (`a::{B, c::D, E as F}`) are used if any member is.
/// * `X as Alias` is checked against `Alias`; a nested path `c::D` against `D`.
///
/// Nested groups (`a::{b::{C}}`) are not parsed precisely.
fn is_import_used(import: &str, code: &str) -> bool {
    let trimmed = import.trim().trim_end_matches(';');
    let path = trimmed.strip_prefix("use ").unwrap_or(trimmed);

    if path.ends_with("::*") {
        return true;
    }

    if let Some(open) = path.find('{') {
        let after_open = &path[open + 1..];
        // Look for `}` only after `{`, so a malformed import cannot produce an
        // inverted (panicking) slice range.
        if let Some(close) = after_open.find('}') {
            return after_open[..close]
                .split(',')
                .map(str::trim)
                .filter(|entry| !entry.is_empty())
                .any(|entry| is_import_entry_used(entry, code));
        }
    }

    is_import_entry_used(path, code)
}

/// Checks one import entry (`path::Name`, `Name as Alias`, or `*`) against `code`.
fn is_import_entry_used(entry: &str, code: &str) -> bool {
    let bound_name = match entry.rsplit_once(" as ") {
        // `X as Alias` brings `Alias` (possibly `_`) into scope, not `X`.
        Some((_, alias)) => alias.trim(),
        None => entry.rsplit("::").next().unwrap_or(entry).trim(),
    };
    bound_name == "*" || code.contains(bound_name)
}
281
/// Inserts blank lines into pretty-printed code at a fixed set of heuristic
/// positions (see [`blank_lines_before`]).
///
/// The result is the input's lines joined with `\n`, so a trailing newline in
/// `code` is not preserved.
fn add_blank_lines_between_items(code: &str) -> String {
    let mut output: Vec<&str> = Vec::new();
    let mut previous: Option<&str> = None;

    for current in code.lines() {
        if let Some(prev) = previous {
            let blanks = blank_lines_before(prev.trim(), current.trim());
            output.extend(std::iter::repeat("").take(blanks));
        }
        output.push(current);
        previous = Some(current);
    }

    output.join("\n")
}

/// Counts the blank lines to insert between two adjacent lines (both trimmed).
///
/// Each matching rule adds one blank line. Only the last two rules can match at
/// the same time, which yields two blank lines.
fn blank_lines_before(prev: &str, current: &str) -> usize {
    // Separate enum variants that carry a `#[sqlx(rename ...)]` attribute.
    let renamed_variant = current.starts_with("#[sqlx(rename") && prev.ends_with(',');

    // Separate a new item from the closing brace of the previous one.
    const ITEM_OPENERS: [&str; 5] = ["pub struct", "impl ", "#[derive", "pub async fn", "pub fn"];
    let new_item = prev == "}" && ITEM_OPENERS.iter().any(|opener| current.starts_with(*opener));

    // Separate `let` / `Ok(` from a preceding awaited or `unwrap_or` statement.
    let prev_ends_await = prev.ends_with(".await?;")
        || prev.ends_with(".await?")
        || (prev.ends_with(';') && prev.contains(".unwrap_or("));
    let after_await = prev_ends_await && (current.starts_with("let ") || current.starts_with("Ok("));

    // Separate a `let` mentioning `sqlx::` from a preceding plain `let`.
    let query_binding = current.starts_with("let ")
        && current.contains("sqlx::")
        && prev.starts_with("let ")
        && !prev.contains("sqlx::");

    [renamed_variant, new_item, after_await, query_binding]
        .iter()
        .filter(|&&rule| rule)
        .count()
}
343
#[cfg(test)]
mod tests {
    //! Unit tests for keyword detection, module naming, derive imports, import
    //! filtering, blank-line post-processing and end-to-end `generate` output.

    use super::*;
    use crate::introspect::{
        ColumnInfo, CompositeTypeInfo, DomainInfo, EnumInfo, SchemaInfo, TableInfo,
    };
    use std::collections::HashMap;

    // ----- is_rust_keyword -----

    #[test]
    fn test_keyword_type() {
        assert!(is_rust_keyword("type"));
    }

    #[test]
    fn test_keyword_fn() {
        assert!(is_rust_keyword("fn"));
    }

    #[test]
    fn test_keyword_let() {
        assert!(is_rust_keyword("let"));
    }

    #[test]
    fn test_keyword_match() {
        assert!(is_rust_keyword("match"));
    }

    #[test]
    fn test_keyword_async() {
        assert!(is_rust_keyword("async"));
    }

    #[test]
    fn test_keyword_await() {
        assert!(is_rust_keyword("await"));
    }

    #[test]
    fn test_keyword_yield() {
        assert!(is_rust_keyword("yield"));
    }

    #[test]
    fn test_keyword_abstract() {
        assert!(is_rust_keyword("abstract"));
    }

    #[test]
    fn test_keyword_try() {
        assert!(is_rust_keyword("try"));
    }

    #[test]
    fn test_not_keyword_name() {
        assert!(!is_rust_keyword("name"));
    }

    #[test]
    fn test_not_keyword_id() {
        assert!(!is_rust_keyword("id"));
    }

    #[test]
    fn test_not_keyword_uppercase_type() {
        assert!(!is_rust_keyword("Type"));
    }

    // ----- normalize_module_name -----

    #[test]
    fn test_normalize_no_underscores() {
        assert_eq!(normalize_module_name("users"), "users");
    }

    #[test]
    fn test_normalize_single_underscore() {
        assert_eq!(normalize_module_name("user_roles"), "user_roles");
    }

    #[test]
    fn test_normalize_double_underscore() {
        assert_eq!(normalize_module_name("user__roles"), "user_roles");
    }

    #[test]
    fn test_normalize_triple_underscore() {
        assert_eq!(normalize_module_name("a___b"), "a_b");
    }

    #[test]
    fn test_normalize_leading_underscore() {
        assert_eq!(normalize_module_name("_private"), "_private");
    }

    #[test]
    fn test_normalize_trailing_underscore() {
        assert_eq!(normalize_module_name("name_"), "name_");
    }

    #[test]
    fn test_normalize_double_leading() {
        assert_eq!(normalize_module_name("__double_leading"), "_double_leading");
    }

    #[test]
    fn test_normalize_multiple_groups() {
        assert_eq!(normalize_module_name("a__b__c"), "a_b_c");
    }

    // ----- build_module_name -----

    #[test]
    fn test_build_no_collision_no_prefix() {
        assert_eq!(build_module_name("public", "users", false), "users");
    }

    #[test]
    fn test_build_no_collision_non_default_no_prefix() {
        assert_eq!(build_module_name("billing", "invoices", false), "invoices");
    }

    #[test]
    fn test_build_collision_prefixed() {
        assert_eq!(build_module_name("billing", "users", true), "billing_users");
    }

    #[test]
    fn test_build_collision_default_schema_no_prefix() {
        assert_eq!(build_module_name("public", "users", true), "users");
    }

    #[test]
    fn test_build_collision_normalizes_double_underscore() {
        assert_eq!(build_module_name("billing", "agent__connector", true), "billing_agent_connector");
    }

    // ----- is_default_schema -----

    #[test]
    fn test_default_schema_public() {
        assert!(is_default_schema("public"));
    }

    #[test]
    fn test_default_schema_main() {
        assert!(is_default_schema("main"));
    }

    #[test]
    fn test_default_schema_dbo() {
        assert!(is_default_schema("dbo"));
    }

    #[test]
    fn test_non_default_schema() {
        assert!(!is_default_schema("billing"));
    }

    // ----- imports_for_derives -----

    #[test]
    fn test_imports_empty() {
        let result = imports_for_derives(&[]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_imports_serialize_only() {
        let derives = vec!["Serialize".to_string()];
        let result = imports_for_derives(&derives);
        assert_eq!(result, vec!["use serde::{Serialize};"]);
    }

    #[test]
    fn test_imports_deserialize_only() {
        let derives = vec!["Deserialize".to_string()];
        let result = imports_for_derives(&derives);
        assert_eq!(result, vec!["use serde::{Deserialize};"]);
    }

    #[test]
    fn test_imports_both_serde() {
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let result = imports_for_derives(&derives);
        assert_eq!(result, vec!["use serde::{Serialize, Deserialize};"]);
    }

    #[test]
    fn test_imports_non_serde() {
        let derives = vec!["Hash".to_string()];
        let result = imports_for_derives(&derives);
        assert!(result.is_empty());
    }

    #[test]
    fn test_imports_non_serde_multiple() {
        let derives = vec!["PartialEq".to_string(), "Eq".to_string()];
        let result = imports_for_derives(&derives);
        assert!(result.is_empty());
    }

    #[test]
    fn test_imports_mixed_serde_and_others() {
        let derives = vec![
            "Serialize".to_string(),
            "Hash".to_string(),
            "Deserialize".to_string(),
        ];
        let result = imports_for_derives(&derives);
        assert_eq!(result.len(), 1);
        assert!(result[0].contains("Serialize"));
        assert!(result[0].contains("Deserialize"));
    }

    // ----- add_blank_lines_between_items -----

    #[test]
    fn test_blank_lines_between_renamed_variants() {
        let input = "pub enum Foo {\n #[sqlx(rename = \"a\")]\n A,\n #[sqlx(rename = \"b\")]\n B,\n}";
        let result = add_blank_lines_between_items(input);
        assert!(result.contains("A,\n\n #[sqlx(rename = \"b\")]"));
    }

    #[test]
    fn test_no_blank_line_for_first_variant() {
        let input = "pub enum Foo {\n #[sqlx(rename = \"a\")]\n A,\n}";
        let result = add_blank_lines_between_items(input);
        // The first variant follows `{`, not `,`, so nothing is inserted.
        assert!(!result.contains("{\n\n"));
    }

    #[test]
    fn test_no_change_without_rename() {
        let input = "pub enum Foo {\n A,\n B,\n}";
        let result = add_blank_lines_between_items(input);
        assert_eq!(result, input);
    }

    #[test]
    fn test_no_change_for_struct() {
        let input = "pub struct Foo {\n pub a: i32,\n pub b: String,\n}";
        let result = add_blank_lines_between_items(input);
        assert_eq!(result, input);
    }

    #[test]
    fn test_blank_line_between_item_and_closing_brace() {
        let input = "pub struct Foo {\n pub a: i32,\n}\nimpl Foo {}";
        let result = add_blank_lines_between_items(input);
        assert_eq!(result, "pub struct Foo {\n pub a: i32,\n}\n\nimpl Foo {}");
    }

    #[test]
    fn test_blank_line_after_await() {
        let input = "let rows = query.fetch_all(pool).await?;\nOk(rows)";
        let result = add_blank_lines_between_items(input);
        assert_eq!(result, "let rows = query.fetch_all(pool).await?;\n\nOk(rows)");
    }

    // ----- is_import_used -----

    #[test]
    fn test_import_used_grouped_any_member() {
        assert!(is_import_used(
            "use serde::{Serialize, Deserialize};",
            "#[derive(Deserialize)]\npub struct Foo {}"
        ));
    }

    #[test]
    fn test_import_unused_single_name() {
        assert!(!is_import_used("use chrono::NaiveDateTime;", "pub struct Foo {}"));
    }

    #[test]
    fn test_import_glob_always_used() {
        assert!(is_import_used("use sqlx::types::*;", ""));
    }

    // ----- filter_imports -----

    #[test]
    fn test_filter_single_file_strips_super_types() {
        let mut imports = BTreeSet::new();
        imports.insert("use super::types::Foo;".to_string());
        imports.insert("use chrono::NaiveDateTime;".to_string());
        let result = filter_imports(&imports, true);
        assert!(!result.contains("use super::types::Foo;"));
        assert!(result.contains("use chrono::NaiveDateTime;"));
    }

    #[test]
    fn test_filter_single_file_keeps_other_imports() {
        let mut imports = BTreeSet::new();
        imports.insert("use chrono::NaiveDateTime;".to_string());
        let result = filter_imports(&imports, true);
        assert!(result.contains("use chrono::NaiveDateTime;"));
    }

    #[test]
    fn test_filter_multi_file_keeps_all() {
        let mut imports = BTreeSet::new();
        imports.insert("use super::types::Foo;".to_string());
        imports.insert("use chrono::NaiveDateTime;".to_string());
        let result = filter_imports(&imports, false);
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_filter_empty_set() {
        let imports = BTreeSet::new();
        let result = filter_imports(&imports, true);
        assert!(result.is_empty());
    }

    // ----- find_colliding_names -----

    #[test]
    fn test_colliding_names_include_views() {
        let schema = SchemaInfo {
            tables: vec![
                make_table("users", vec![]),
                make_table("posts", vec![]),
            ],
            views: vec![TableInfo {
                schema_name: "billing".to_string(),
                name: "users".to_string(),
                columns: vec![],
            }],
            ..Default::default()
        };
        let names = find_colliding_names(&schema);
        assert!(names.contains("users"));
        assert!(!names.contains("posts"));
    }

    // ----- generate -----

    /// Builds a table in the `public` schema.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// Builds a non-nullable, non-key column whose data type and UDT are both `udt_name`.
    fn make_col(name: &str, udt_name: &str) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: false,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
        }
    }

    #[test]
    fn test_generate_empty_schema() {
        let schema = SchemaInfo::default();
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert!(files.is_empty());
    }

    #[test]
    fn test_generate_one_table() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].filename, "users.rs");
    }

    #[test]
    fn test_generate_two_tables() {
        let schema = SchemaInfo {
            tables: vec![
                make_table("users", vec![make_col("id", "int4")]),
                make_table("posts", vec![make_col("id", "int4")]),
            ],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 2);
    }

    #[test]
    fn test_generate_enum_creates_types_file() {
        let schema = SchemaInfo {
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string(), "inactive".to_string()],
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].filename, "types.rs");
    }

    #[test]
    fn test_generate_enums_composites_domains_single_types_file() {
        let schema = SchemaInfo {
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string()],
            }],
            composite_types: vec![CompositeTypeInfo {
                schema_name: "public".to_string(),
                name: "address".to_string(),
                fields: vec![make_col("street", "text")],
            }],
            domains: vec![DomainInfo {
                schema_name: "public".to_string(),
                name: "email".to_string(),
                base_type: "text".to_string(),
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        let types_files: Vec<_> = files.iter().filter(|f| f.filename == "types.rs").collect();
        assert_eq!(types_files.len(), 1);
    }

    #[test]
    fn test_generate_tables_and_enums() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string()],
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 2);
    }

    #[test]
    fn test_generate_filename_normalized() {
        let schema = SchemaInfo {
            tables: vec![make_table("user__data", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files[0].filename, "user_data.rs");
    }

    #[test]
    fn test_generate_origin_correct() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files[0].origin, Some("Table: public.users".to_string()));
    }

    #[test]
    fn test_generate_types_no_origin() {
        let schema = SchemaInfo {
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string()],
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files[0].origin, None);
    }

    #[test]
    fn test_generate_single_file_filters_super_types_imports() {
        // The column is typed as the `status` enum. In multi-file mode the struct
        // file references `super::types::` (see the multi-file test below), so
        // this test only proves something if the same column is used here.
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("status", "status")])],
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string()],
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), true);
        let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
        assert!(!struct_file.code.contains("super::types::"));
    }

    #[test]
    fn test_generate_multi_file_keeps_super_types_imports() {
        // Column typed as the `status` enum, so the struct refers to `super::types`.
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("status", "status")])],
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string()],
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
        assert!(struct_file.code.contains("super::types::"));
    }

    #[test]
    fn test_generate_extra_derives_in_struct() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let derives = vec!["Serialize".to_string()];
        let files = generate(&schema, DatabaseKind::Postgres, &derives, &HashMap::new(), false);
        assert!(files[0].code.contains("Serialize"));
    }

    #[test]
    fn test_generate_extra_derives_in_enum() {
        let schema = SchemaInfo {
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string()],
            }],
            ..Default::default()
        };
        let derives = vec!["Serialize".to_string()];
        let files = generate(&schema, DatabaseKind::Postgres, &derives, &HashMap::new(), false);
        assert!(files[0].code.contains("Serialize"));
    }

    #[test]
    fn test_generate_type_overrides_in_struct() {
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("data", "jsonb")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &overrides, false);
        assert!(files[0].code.contains("MyJson"));
    }

    #[test]
    fn test_generate_valid_rust_syntax() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![
                make_col("id", "int4"),
                make_col("name", "text"),
            ])],
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string(), "inactive".to_string()],
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        for f in &files {
            let parse_result = syn::parse_file(&f.code);
            assert!(parse_result.is_ok(), "Failed to parse {}: {:?}", f.filename, parse_result.err());
        }
    }

    /// Builds a view in the `public` schema; views share `TableInfo` with tables.
    fn make_view(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        make_table(name, columns)
    }

    #[test]
    fn test_generate_one_view() {
        let schema = SchemaInfo {
            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].filename, "active_users.rs");
    }

    #[test]
    fn test_generate_view_origin() {
        let schema = SchemaInfo {
            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files[0].origin, Some("View: public.active_users".to_string()));
    }

    #[test]
    fn test_generate_tables_and_views() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 2);
    }

    #[test]
    fn test_generate_view_valid_rust() {
        let schema = SchemaInfo {
            views: vec![make_view("active_users", vec![
                make_col("id", "int4"),
                make_col("name", "text"),
            ])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        let parse_result = syn::parse_file(&files[0].code);
        assert!(parse_result.is_ok(), "Failed to parse: {:?}", parse_result.err());
    }

    #[test]
    fn test_generate_view_nullable_column() {
        let schema = SchemaInfo {
            views: vec![make_view("v", vec![ColumnInfo {
                name: "email".to_string(),
                data_type: "text".to_string(),
                udt_name: "text".to_string(),
                is_nullable: true,
                is_primary_key: false,
                ordinal_position: 0,
                schema_name: "public".to_string(),
            }])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert!(files[0].code.contains("Option<String>"));
    }

    #[test]
    fn test_generate_collision_both_prefixed() {
        let schema = SchemaInfo {
            tables: vec![
                make_table("users", vec![make_col("id", "int4")]),
                TableInfo {
                    schema_name: "billing".to_string(),
                    name: "users".to_string(),
                    columns: vec![make_col("id", "int4")],
                },
            ],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        let filenames: Vec<_> = files.iter().map(|f| f.filename.as_str()).collect();
        assert!(filenames.contains(&"users.rs"));
        assert!(filenames.contains(&"billing_users.rs"));
    }

    #[test]
    fn test_generate_no_collision_no_prefix() {
        let schema = SchemaInfo {
            tables: vec![
                make_table("users", vec![make_col("id", "int4")]),
                TableInfo {
                    schema_name: "billing".to_string(),
                    name: "invoices".to_string(),
                    columns: vec![make_col("id", "int4")],
                },
            ],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        let filenames: Vec<_> = files.iter().map(|f| f.filename.as_str()).collect();
        assert!(filenames.contains(&"users.rs"));
        assert!(filenames.contains(&"invoices.rs"));
    }

    #[test]
    fn test_generate_single_schema_no_prefix() {
        let schema = SchemaInfo {
            tables: vec![
                make_table("users", vec![make_col("id", "int4")]),
                make_table("posts", vec![make_col("id", "int4")]),
            ],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files[0].filename, "users.rs");
        assert_eq!(files[1].filename, "posts.rs");
    }

    #[test]
    fn test_generate_view_single_file_mode() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), true);
        assert_eq!(files.len(), 2);
    }
}