// sqlx_gen/codegen/domain_gen.rs

1use std::collections::{BTreeSet, HashMap};
2
3use proc_macro2::TokenStream;
4use quote::{format_ident, quote};
5
6use crate::cli::{DatabaseKind, DomainStyle, TimeCrate};
7use crate::codegen::rust_type_name_for;
8use crate::introspect::{DomainInfo, SchemaInfo};
9use crate::typemap;
10
11pub fn generate_domain(
12    domain: &DomainInfo,
13    db_kind: DatabaseKind,
14    schema_info: &SchemaInfo,
15    type_overrides: &HashMap<String, String>,
16    time_crate: TimeCrate,
17) -> (TokenStream, BTreeSet<String>) {
18    generate_domain_with_style(
19        domain,
20        db_kind,
21        schema_info,
22        type_overrides,
23        time_crate,
24        DomainStyle::Alias,
25    )
26}
27
28pub fn generate_domain_with_style(
29    domain: &DomainInfo,
30    db_kind: DatabaseKind,
31    schema_info: &SchemaInfo,
32    type_overrides: &HashMap<String, String>,
33    time_crate: TimeCrate,
34    style: DomainStyle,
35) -> (TokenStream, BTreeSet<String>) {
36    let mut imports = BTreeSet::new();
37    let rust_name = rust_type_name_for(schema_info, &domain.schema_name, &domain.name);
38    let alias_name = format_ident!("{}", rust_name);
39
40    let doc = format!(
41        "Domain: {}.{} (base: {})",
42        domain.schema_name, domain.name, domain.base_type
43    );
44
45    // Create a fake ColumnInfo to reuse the type mapper for the base type
46    let fake_col = crate::introspect::ColumnInfo {
47        name: String::new(),
48        data_type: domain.base_type.clone(),
49        udt_name: domain.base_type.clone(),
50        udt_schema: None,
51        is_nullable: false,
52        is_primary_key: false,
53        ordinal_position: 0,
54        schema_name: domain.schema_name.clone(),
55        column_default: None,
56    };
57
58    let rust_type =
59        typemap::map_column(&fake_col, db_kind, schema_info, type_overrides, time_crate);
60    if let Some(imp) = &rust_type.needs_import {
61        imports.insert(imp.clone());
62    }
63
64    let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
65        let fallback = format_ident!("String");
66        quote! { #fallback }
67    });
68
69    let domain_doc = "sqlx_gen:kind=domain";
70    let tokens = match style {
71        DomainStyle::Alias => quote! {
72            #[doc = #doc]
73            #[doc = #domain_doc]
74            pub type #alias_name = #type_tokens;
75        },
76        DomainStyle::Newtype => {
77            imports.insert("use serde::{Serialize, Deserialize};".to_string());
78            quote! {
79                #[doc = #doc]
80                #[doc = #domain_doc]
81                #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, sqlx::Type)]
82                #[sqlx(transparent)]
83                pub struct #alias_name(pub #type_tokens);
84            }
85        }
86    };
87
88    (tokens, imports)
89}
90
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;

    /// Builds a `public`-schema domain named `name` over the base type `base`.
    fn make_domain(name: &str, base: &str) -> DomainInfo {
        DomainInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            base_type: base.to_string(),
        }
    }

    /// Pretty-prints generated tokens, panicking if they are not valid Rust, and
    /// passes the collected imports through unchanged.
    fn render((tokens, imports): (TokenStream, BTreeSet<String>)) -> (String, BTreeSet<String>) {
        let code = parse_and_format(&tokens).expect("generated domain code should parse");
        (code, imports)
    }

    /// Alias-style output through the default [`generate_domain`] entry point,
    /// using the given type overrides.
    fn gen_alias_with_overrides(
        domain: &DomainInfo,
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let schema = SchemaInfo::default();
        render(generate_domain(
            domain,
            DatabaseKind::Postgres,
            &schema,
            overrides,
            TimeCrate::Chrono,
        ))
    }

    /// Alias-style output with no type overrides.
    // Deliberately not named `gen`: that identifier is a reserved keyword in Rust 2024.
    fn gen_alias(domain: &DomainInfo) -> (String, BTreeSet<String>) {
        gen_alias_with_overrides(domain, &HashMap::new())
    }

    // ========== DomainStyle::Alias (default) ==========

    #[test]
    fn test_domain_text() {
        let d = make_domain("email", "text");
        let (code, _) = gen_alias(&d);
        assert!(code.contains("pub type Email = String"));
    }

    #[test]
    fn test_domain_int4() {
        let d = make_domain("positive_int", "int4");
        let (code, _) = gen_alias(&d);
        assert!(code.contains("pub type PositiveInt = i32"));
    }

    #[test]
    fn test_domain_uuid() {
        let d = make_domain("my_uuid", "uuid");
        let (code, imports) = gen_alias(&d);
        assert!(code.contains("pub type MyUuid = Uuid"));
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_doc_comment() {
        let d = make_domain("email", "text");
        let (code, _) = gen_alias(&d);
        assert!(code.contains("Domain: public.email (base: text)"));
    }

    #[test]
    fn test_import_when_needed() {
        let d = make_domain("my_uuid", "uuid");
        let (_, imports) = gen_alias(&d);
        assert!(!imports.is_empty());
    }

    #[test]
    fn test_no_import_simple_type() {
        let d = make_domain("email", "text");
        let (_, imports) = gen_alias(&d);
        assert!(imports.is_empty());
    }

    #[test]
    fn test_pascal_case_name() {
        let d = make_domain("email_address", "text");
        let (code, _) = gen_alias(&d);
        assert!(code.contains("pub type EmailAddress"));
    }

    #[test]
    fn test_type_override() {
        let d = make_domain("json_data", "jsonb");
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = gen_alias_with_overrides(&d, &overrides);
        assert!(code.contains("pub type JsonData = MyJson"));
    }

    #[test]
    fn test_domain_jsonb() {
        let d = make_domain("data", "jsonb");
        let (code, imports) = gen_alias(&d);
        assert!(code.contains("Value"));
        assert!(imports.iter().any(|i| i.contains("serde_json")));
    }

    #[test]
    fn test_domain_timestamptz() {
        let d = make_domain("created", "timestamptz");
        let (_, imports) = gen_alias(&d);
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }

    // ========== DomainStyle::Newtype ==========

    /// Newtype-style output with no type overrides.
    fn gen_newtype(domain: &DomainInfo) -> (String, BTreeSet<String>) {
        let schema = SchemaInfo::default();
        render(generate_domain_with_style(
            domain,
            DatabaseKind::Postgres,
            &schema,
            &HashMap::new(),
            TimeCrate::Chrono,
            DomainStyle::Newtype,
        ))
    }

    #[test]
    fn test_newtype_emits_tuple_struct() {
        let d = make_domain("email", "text");
        let (code, _) = gen_newtype(&d);
        assert!(
            code.contains("pub struct Email(pub String)"),
            "newtype must wrap the base type in a tuple struct, got:\n{}",
            code
        );
    }

    #[test]
    fn test_newtype_uses_transparent_derive() {
        let d = make_domain("email", "text");
        let (code, _) = gen_newtype(&d);
        assert!(code.contains("#[sqlx(transparent)]"));
        assert!(code.contains("sqlx::Type"));
    }

    #[test]
    fn test_newtype_keeps_doc_comments() {
        let d = make_domain("email", "text");
        let (code, _) = gen_newtype(&d);
        assert!(code.contains("Domain: public.email (base: text)"));
        assert!(code.contains("sqlx_gen:kind=domain"));
    }

    #[test]
    fn test_newtype_wraps_uuid_with_import() {
        let d = make_domain("my_uuid", "uuid");
        let (code, imports) = gen_newtype(&d);
        assert!(code.contains("pub struct MyUuid(pub Uuid)"));
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_newtype_does_not_emit_type_alias() {
        let d = make_domain("email", "text");
        let (code, _) = gen_newtype(&d);
        assert!(!code.contains("pub type Email"));
    }
}
257        let d = make_domain("email", "text");
258        let (code, _) = gen_newtype(&d);
259        assert!(!code.contains("pub type Email"));
260    }
261}