Skip to main content

sqlx_gen/codegen/
domain_gen.rs

1use std::collections::{BTreeSet, HashMap};
2
3use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::{DatabaseKind, TimeCrate};
8use crate::introspect::{DomainInfo, SchemaInfo};
9use crate::typemap;
10
11pub fn generate_domain(
12    domain: &DomainInfo,
13    db_kind: DatabaseKind,
14    schema_info: &SchemaInfo,
15    type_overrides: &HashMap<String, String>,
16    time_crate: TimeCrate,
17) -> (TokenStream, BTreeSet<String>) {
18    let mut imports = BTreeSet::new();
19    let alias_name = format_ident!("{}", domain.name.to_upper_camel_case());
20
21    let doc = format!(
22        "Domain: {}.{} (base: {})",
23        domain.schema_name, domain.name, domain.base_type
24    );
25
26    // Create a fake ColumnInfo to reuse the type mapper for the base type
27    let fake_col = crate::introspect::ColumnInfo {
28        name: String::new(),
29        data_type: domain.base_type.clone(),
30        udt_name: domain.base_type.clone(),
31        is_nullable: false,
32        is_primary_key: false,
33        ordinal_position: 0,
34        schema_name: domain.schema_name.clone(),
35        column_default: None,
36    };
37
38    let rust_type = typemap::map_column(&fake_col, db_kind, schema_info, type_overrides, time_crate);
39    if let Some(imp) = &rust_type.needs_import {
40        imports.insert(imp.clone());
41    }
42
43    let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
44        let fallback = format_ident!("String");
45        quote! { #fallback }
46    });
47
48    let domain_doc = "sqlx_gen:kind=domain";
49    let tokens = quote! {
50        #[doc = #doc]
51        #[doc = #domain_doc]
52        pub type #alias_name = #type_tokens;
53    };
54
55    (tokens, imports)
56}
57
#[cfg(test)]
mod tests {
    //! Unit tests for `generate_domain`, targeting Postgres with an empty
    //! `SchemaInfo` and the `chrono` time crate.

    use super::*;
    use crate::codegen::parse_and_format;

    /// Builds a `DomainInfo` in the `public` schema.
    fn make_domain(name: &str, base: &str) -> DomainInfo {
        DomainInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            base_type: base.to_string(),
        }
    }

    /// Generates the alias for `domain` with the given type overrides and
    /// returns the formatted source together with the collected imports.
    fn render_with_overrides(
        domain: &DomainInfo,
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let schema = SchemaInfo::default();
        let (tokens, imports) =
            generate_domain(domain, DatabaseKind::Postgres, &schema, overrides, TimeCrate::Chrono);
        (parse_and_format(&tokens), imports)
    }

    /// Same as `render_with_overrides` with no type overrides.
    // Not named `gen`: that identifier is a reserved keyword in the Rust 2024 edition.
    fn render(domain: &DomainInfo) -> (String, BTreeSet<String>) {
        render_with_overrides(domain, &HashMap::new())
    }

    #[test]
    fn test_domain_text() {
        let d = make_domain("email", "text");
        let (code, _) = render(&d);
        assert!(code.contains("pub type Email = String"));
    }

    #[test]
    fn test_domain_int4() {
        let d = make_domain("positive_int", "int4");
        let (code, _) = render(&d);
        assert!(code.contains("pub type PositiveInt = i32"));
    }

    #[test]
    fn test_domain_uuid() {
        let d = make_domain("my_uuid", "uuid");
        let (code, imports) = render(&d);
        assert!(code.contains("pub type MyUuid = Uuid"));
        assert!(imports.iter().any(|i| i.contains("uuid::Uuid")));
    }

    #[test]
    fn test_doc_comment() {
        let d = make_domain("email", "text");
        let (code, _) = render(&d);
        assert!(code.contains("Domain: public.email (base: text)"));
    }

    #[test]
    fn test_import_when_needed() {
        let d = make_domain("my_uuid", "uuid");
        let (_, imports) = render(&d);
        assert!(!imports.is_empty());
    }

    #[test]
    fn test_no_import_simple_type() {
        let d = make_domain("email", "text");
        let (_, imports) = render(&d);
        assert!(imports.is_empty());
    }

    #[test]
    fn test_pascal_case_name() {
        let d = make_domain("email_address", "text");
        let (code, _) = render(&d);
        assert!(code.contains("pub type EmailAddress"));
    }

    #[test]
    fn test_type_override() {
        let d = make_domain("json_data", "jsonb");
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let (code, _) = render_with_overrides(&d, &overrides);
        assert!(code.contains("pub type JsonData = MyJson"));
    }

    #[test]
    fn test_domain_jsonb() {
        let d = make_domain("data", "jsonb");
        let (code, imports) = render(&d);
        assert!(code.contains("Value"));
        assert!(imports.iter().any(|i| i.contains("serde_json")));
    }

    #[test]
    fn test_domain_timestamptz() {
        let d = make_domain("created", "timestamptz");
        let (_, imports) = render(&d);
        assert!(imports.iter().any(|i| i.contains("chrono")));
    }
}