1use std::collections::{BTreeSet, HashMap};
2
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5
6use crate::cli::{DatabaseKind, DomainStyle, TimeCrate};
7use crate::codegen::rust_type_name_for;
8use crate::introspect::{DomainInfo, SchemaInfo};
9use crate::typemap;
10
11pub fn generate_domain(
12 domain: &DomainInfo,
13 db_kind: DatabaseKind,
14 schema_info: &SchemaInfo,
15 type_overrides: &HashMap<String, String>,
16 time_crate: TimeCrate,
17) -> (TokenStream, BTreeSet<String>) {
18 generate_domain_with_style(
19 domain,
20 db_kind,
21 schema_info,
22 type_overrides,
23 time_crate,
24 DomainStyle::Alias,
25 )
26}
27
28pub fn generate_domain_with_style(
29 domain: &DomainInfo,
30 db_kind: DatabaseKind,
31 schema_info: &SchemaInfo,
32 type_overrides: &HashMap<String, String>,
33 time_crate: TimeCrate,
34 style: DomainStyle,
35) -> (TokenStream, BTreeSet<String>) {
36 let mut imports = BTreeSet::new();
37 let rust_name = rust_type_name_for(schema_info, &domain.schema_name, &domain.name);
38 let alias_name = format_ident!("{}", rust_name);
39
40 let doc = format!(
41 "Domain: {}.{} (base: {})",
42 domain.schema_name, domain.name, domain.base_type
43 );
44
45 let fake_col = crate::introspect::ColumnInfo {
47 name: String::new(),
48 data_type: domain.base_type.clone(),
49 udt_name: domain.base_type.clone(),
50 udt_schema: None,
51 is_nullable: false,
52 is_primary_key: false,
53 ordinal_position: 0,
54 schema_name: domain.schema_name.clone(),
55 column_default: None,
56 };
57
58 let rust_type =
59 typemap::map_column(&fake_col, db_kind, schema_info, type_overrides, time_crate);
60 if let Some(imp) = &rust_type.needs_import {
61 imports.insert(imp.clone());
62 }
63
64 let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
65 let fallback = format_ident!("String");
66 quote! { #fallback }
67 });
68
69 let domain_doc = "sqlx_gen:kind=domain";
70 let tokens = match style {
71 DomainStyle::Alias => quote! {
72 #[doc = #doc]
73 #[doc = #domain_doc]
74 pub type #alias_name = #type_tokens;
75 },
76 DomainStyle::Newtype => {
77 imports.insert("use serde::{Serialize, Deserialize};".to_string());
78 quote! {
79 #[doc = #doc]
80 #[doc = #domain_doc]
81 #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)]
82 #[sqlx(transparent)]
83 pub struct #alias_name(pub #type_tokens);
84 }
85 }
86 };
87
88 (tokens, imports)
89}
90
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;

    /// Builds a domain called `name` over `base` in the `public` schema.
    fn make_domain(name: &str, base: &str) -> DomainInfo {
        DomainInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            base_type: base.to_string(),
        }
    }

    /// Renders `domain` through [`generate_domain`] (alias style) with no type overrides.
    ///
    /// Deliberately not named `gen`: that identifier is a reserved keyword in the
    /// Rust 2024 edition.
    fn render_alias(domain: &DomainInfo) -> (String, BTreeSet<String>) {
        render_alias_with_overrides(domain, &HashMap::new())
    }

    /// Renders `domain` through [`generate_domain`] (alias style) using `overrides`.
    ///
    /// Uses Postgres, a default schema and the chrono time crate, and returns the
    /// formatted source together with the collected imports.
    fn render_alias_with_overrides(
        domain: &DomainInfo,
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let schema = SchemaInfo::default();
        let (tokens, imports) = generate_domain(
            domain,
            DatabaseKind::Postgres,
            &schema,
            overrides,
            TimeCrate::Chrono,
        );
        (parse_and_format(&tokens).unwrap(), imports)
    }

    /// Renders `domain` through [`generate_domain_with_style`] in the newtype style.
    fn render_newtype(domain: &DomainInfo) -> (String, BTreeSet<String>) {
        let schema = SchemaInfo::default();
        let (tokens, imports) = generate_domain_with_style(
            domain,
            DatabaseKind::Postgres,
            &schema,
            &HashMap::new(),
            TimeCrate::Chrono,
            DomainStyle::Newtype,
        );
        (parse_and_format(&tokens).unwrap(), imports)
    }

    #[test]
    fn test_domain_text() {
        let d = make_domain("email", "text");
        let (code, _) = render_alias(&d);
        assert!(code.contains("pub type Email = String"));
    }

    #[test]
    fn test_domain_int4() {
        let d = make_domain("positive_int", "int4");
        let (code, _) = render_alias(&d);
        assert!(code.contains("pub type PositiveInt = i32"));
    }

    #[test]
    fn test_domain_uuid() {
        let d = make_domain("my_uuid", "uuid");
        let (code, imports) = render_alias(&d);
        assert!(code.contains("pub type MyUuid = Uuid"));
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_doc_comment() {
        let d = make_domain("email", "text");
        let (code, _) = render_alias(&d);
        assert!(code.contains("Domain: public.email (base: text)"));
    }

    #[test]
    fn test_import_when_needed() {
        let d = make_domain("my_uuid", "uuid");
        let (_, imports) = render_alias(&d);
        assert!(!imports.is_empty());
    }

    #[test]
    fn test_no_import_simple_type() {
        let d = make_domain("email", "text");
        let (_, imports) = render_alias(&d);
        assert!(imports.is_empty());
    }

    #[test]
    fn test_pascal_case_name() {
        let d = make_domain("email_address", "text");
        let (code, _) = render_alias(&d);
        assert!(code.contains("pub type EmailAddress"));
    }

    #[test]
    fn test_type_override() {
        let d = make_domain("json_data", "jsonb");
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = render_alias_with_overrides(&d, &overrides);
        assert!(code.contains("pub type JsonData = MyJson"));
    }

    #[test]
    fn test_domain_jsonb() {
        let d = make_domain("data", "jsonb");
        let (code, imports) = render_alias(&d);
        assert!(code.contains("Value"));
        assert!(imports.iter().any(|i| i.contains("serde_json")));
    }

    #[test]
    fn test_domain_timestamptz() {
        let d = make_domain("created", "timestamptz");
        let (_, imports) = render_alias(&d);
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    #[test]
    fn test_newtype_emits_tuple_struct() {
        let d = make_domain("email", "text");
        let (code, _) = render_newtype(&d);
        assert!(
            code.contains("pub struct Email(pub String)"),
            "newtype must wrap the base type in a tuple struct, got:\n{}",
            code
        );
    }

    #[test]
    fn test_newtype_uses_transparent_derive() {
        let d = make_domain("email", "text");
        let (code, _) = render_newtype(&d);
        assert!(code.contains("#[sqlx(transparent)]"));
        assert!(code.contains("sqlx::Type"));
    }

    #[test]
    fn test_newtype_keeps_doc_comments() {
        let d = make_domain("email", "text");
        let (code, _) = render_newtype(&d);
        assert!(code.contains("Domain: public.email (base: text)"));
        assert!(code.contains("sqlx_gen:kind=domain"));
    }

    #[test]
    fn test_newtype_wraps_uuid_with_import() {
        let d = make_domain("my_uuid", "uuid");
        let (code, imports) = render_newtype(&d);
        assert!(code.contains("pub struct MyUuid(pub Uuid)"));
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_newtype_does_not_emit_type_alias() {
        let d = make_domain("email", "text");
        let (code, _) = render_newtype(&d);
        assert!(!code.contains("pub type Email"));
    }
}