Skip to main content

sqlx_gen/codegen/
domain_gen.rs

1use std::collections::{BTreeSet, HashMap};
2
3use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::introspect::{DomainInfo, SchemaInfo};
9use crate::typemap;
10
11pub fn generate_domain(
12    domain: &DomainInfo,
13    db_kind: DatabaseKind,
14    schema_info: &SchemaInfo,
15    type_overrides: &HashMap<String, String>,
16) -> (TokenStream, BTreeSet<String>) {
17    let mut imports = BTreeSet::new();
18    let alias_name = format_ident!("{}", domain.name.to_upper_camel_case());
19
20    let doc = format!(
21        "Domain: {}.{} (base: {})",
22        domain.schema_name, domain.name, domain.base_type
23    );
24
25    // Create a fake ColumnInfo to reuse the type mapper for the base type
26    let fake_col = crate::introspect::ColumnInfo {
27        name: String::new(),
28        data_type: domain.base_type.clone(),
29        udt_name: domain.base_type.clone(),
30        is_nullable: false,
31        is_primary_key: false,
32        ordinal_position: 0,
33        schema_name: domain.schema_name.clone(),
34        column_default: None,
35    };
36
37    let rust_type = typemap::map_column(&fake_col, db_kind, schema_info, type_overrides);
38    if let Some(imp) = &rust_type.needs_import {
39        imports.insert(imp.clone());
40    }
41
42    let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
43        let fallback = format_ident!("String");
44        quote! { #fallback }
45    });
46
47    let domain_doc = "sqlx_gen:kind=domain";
48    let tokens = quote! {
49        #[doc = #doc]
50        #[doc = #domain_doc]
51        pub type #alias_name = #type_tokens;
52    };
53
54    (tokens, imports)
55}
56
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;

    /// Builds a domain in the `public` schema.
    fn make_domain(name: &str, base: &str) -> DomainInfo {
        DomainInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            base_type: base.to_string(),
        }
    }

    /// Generates a Postgres domain with no type overrides and returns the
    /// formatted code plus its imports.
    // Named `render` rather than `gen`: `gen` is a reserved keyword in Rust 2024.
    fn render(domain: &DomainInfo) -> (String, BTreeSet<String>) {
        render_with_overrides(domain, &HashMap::new())
    }

    /// Generates a Postgres domain using `overrides` and returns the
    /// formatted code plus its imports.
    fn render_with_overrides(
        domain: &DomainInfo,
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let schema = SchemaInfo::default();
        let (tokens, imports) = generate_domain(domain, DatabaseKind::Postgres, &schema, overrides);
        (parse_and_format(&tokens), imports)
    }

    #[test]
    fn test_domain_text() {
        let d = make_domain("email", "text");
        let (code, _) = render(&d);
        assert!(code.contains("pub type Email = String"));
    }

    #[test]
    fn test_domain_int4() {
        let d = make_domain("positive_int", "int4");
        let (code, _) = render(&d);
        assert!(code.contains("pub type PositiveInt = i32"));
    }

    #[test]
    fn test_domain_uuid() {
        let d = make_domain("my_uuid", "uuid");
        let (code, imports) = render(&d);
        assert!(code.contains("pub type MyUuid = Uuid"));
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_doc_comment() {
        let d = make_domain("email", "text");
        let (code, _) = render(&d);
        assert!(code.contains("Domain: public.email (base: text)"));
    }

    #[test]
    fn test_doc_comment_non_public_schema() {
        let d = DomainInfo {
            schema_name: "billing".to_string(),
            name: "amount".to_string(),
            base_type: "int4".to_string(),
        };
        let (code, _) = render(&d);
        assert!(code.contains("Domain: billing.amount (base: int4)"));
    }

    #[test]
    fn test_kind_marker() {
        let d = make_domain("email", "text");
        let (code, _) = render(&d);
        assert!(code.contains("sqlx_gen:kind=domain"));
    }

    #[test]
    fn test_import_when_needed() {
        let d = make_domain("my_uuid", "uuid");
        let (_, imports) = render(&d);
        assert!(!imports.is_empty());
    }

    #[test]
    fn test_no_import_simple_type() {
        let d = make_domain("email", "text");
        let (_, imports) = render(&d);
        assert!(imports.is_empty());
    }

    #[test]
    fn test_pascal_case_name() {
        let d = make_domain("email_address", "text");
        let (code, _) = render(&d);
        assert!(code.contains("pub type EmailAddress"));
    }

    #[test]
    fn test_type_override() {
        let d = make_domain("json_data", "jsonb");
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = render_with_overrides(&d, &overrides);
        assert!(code.contains("pub type JsonData = MyJson"));
    }

    #[test]
    fn test_domain_jsonb() {
        let d = make_domain("data", "jsonb");
        let (code, imports) = render(&d);
        assert!(code.contains("Value"));
        assert!(imports.iter().any(|i| i.contains("serde_json")));
    }

    #[test]
    fn test_domain_timestamptz() {
        let d = make_domain("created", "timestamptz");
        let (_, imports) = render(&d);
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }
}