sqlx_gen/codegen/
domain_gen.rs1use std::collections::{BTreeSet, HashMap};
2
3use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::introspect::{DomainInfo, SchemaInfo};
9use crate::typemap;
10
11pub fn generate_domain(
12 domain: &DomainInfo,
13 db_kind: DatabaseKind,
14 schema_info: &SchemaInfo,
15 type_overrides: &HashMap<String, String>,
16) -> (TokenStream, BTreeSet<String>) {
17 let mut imports = BTreeSet::new();
18 let alias_name = format_ident!("{}", domain.name.to_upper_camel_case());
19
20 let doc = format!(
21 "Domain: {}.{} (base: {})",
22 domain.schema_name, domain.name, domain.base_type
23 );
24
25 let fake_col = crate::introspect::ColumnInfo {
27 name: String::new(),
28 data_type: domain.base_type.clone(),
29 udt_name: domain.base_type.clone(),
30 is_nullable: false,
31 is_primary_key: false,
32 ordinal_position: 0,
33 schema_name: domain.schema_name.clone(),
34 column_default: None,
35 };
36
37 let rust_type = typemap::map_column(&fake_col, db_kind, schema_info, type_overrides);
38 if let Some(imp) = &rust_type.needs_import {
39 imports.insert(imp.clone());
40 }
41
42 let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
43 let fallback = format_ident!("String");
44 quote! { #fallback }
45 });
46
47 let domain_doc = "sqlx_gen:kind=domain";
48 let tokens = quote! {
49 #[doc = #doc]
50 #[doc = #domain_doc]
51 pub type #alias_name = #type_tokens;
52 };
53
54 (tokens, imports)
55}
56
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;

    /// Builds a domain named `name` over `base` in the given `schema`.
    fn make_domain_in(schema: &str, name: &str, base: &str) -> DomainInfo {
        DomainInfo {
            schema_name: schema.to_string(),
            name: name.to_string(),
            base_type: base.to_string(),
        }
    }

    /// Builds a domain named `name` over `base` in the `public` schema.
    fn make_domain(name: &str, base: &str) -> DomainInfo {
        make_domain_in("public", name, base)
    }

    /// Generates a Postgres domain alias with no type overrides and returns the
    /// formatted code plus the collected imports.
    ///
    /// Named `render` rather than `gen`, since `gen` is reserved in Rust 2024.
    fn render(domain: &DomainInfo) -> (String, BTreeSet<String>) {
        render_with_overrides(domain, &HashMap::new())
    }

    /// Generates a Postgres domain alias using `overrides` and returns the
    /// formatted code plus the collected imports.
    fn render_with_overrides(
        domain: &DomainInfo,
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let schema = SchemaInfo::default();
        let (tokens, imports) =
            generate_domain(domain, DatabaseKind::Postgres, &schema, overrides);
        (parse_and_format(&tokens), imports)
    }

    #[test]
    fn test_domain_text() {
        let d = make_domain("email", "text");
        let (code, _) = render(&d);
        assert!(code.contains("pub type Email = String"));
    }

    #[test]
    fn test_domain_int4() {
        let d = make_domain("positive_int", "int4");
        let (code, _) = render(&d);
        assert!(code.contains("pub type PositiveInt = i32"));
    }

    #[test]
    fn test_domain_uuid() {
        let d = make_domain("my_uuid", "uuid");
        let (code, imports) = render(&d);
        assert!(code.contains("pub type MyUuid = Uuid"));
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_doc_comment() {
        let d = make_domain("email", "text");
        let (code, _) = render(&d);
        assert!(code.contains("Domain: public.email (base: text)"));
    }

    #[test]
    fn test_doc_comment_uses_domain_schema() {
        let d = make_domain_in("auth", "email", "text");
        let (code, _) = render(&d);
        assert!(code.contains("Domain: auth.email (base: text)"));
    }

    #[test]
    fn test_domain_kind_marker() {
        let d = make_domain("email", "text");
        let (code, _) = render(&d);
        assert!(code.contains("sqlx_gen:kind=domain"));
    }

    #[test]
    fn test_import_when_needed() {
        let d = make_domain("my_uuid", "uuid");
        let (_, imports) = render(&d);
        assert!(!imports.is_empty());
    }

    #[test]
    fn test_no_import_simple_type() {
        let d = make_domain("email", "text");
        let (_, imports) = render(&d);
        assert!(imports.is_empty());
    }

    #[test]
    fn test_pascal_case_name() {
        let d = make_domain("email_address", "text");
        let (code, _) = render(&d);
        assert!(code.contains("pub type EmailAddress"));
    }

    #[test]
    fn test_type_override() {
        let d = make_domain("json_data", "jsonb");
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = render_with_overrides(&d, &overrides);
        assert!(code.contains("pub type JsonData = MyJson"));
    }

    #[test]
    fn test_domain_jsonb() {
        let d = make_domain("data", "jsonb");
        let (code, imports) = render(&d);
        assert!(code.contains("Value"));
        assert!(imports.iter().any(|i| i.contains("serde_json")));
    }

    #[test]
    fn test_domain_timestamptz() {
        let d = make_domain("created", "timestamptz");
        let (_, imports) = render(&d);
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }
}