Skip to main content

sqlx_gen/codegen/
domain_gen.rs

1use std::collections::{BTreeSet, HashMap};
2
3use heck::ToUpperCamelCase;
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6
7use crate::cli::DatabaseKind;
8use crate::introspect::{DomainInfo, SchemaInfo};
9use crate::typemap;
10
11pub fn generate_domain(
12    domain: &DomainInfo,
13    db_kind: DatabaseKind,
14    schema_info: &SchemaInfo,
15    type_overrides: &HashMap<String, String>,
16) -> (TokenStream, BTreeSet<String>) {
17    let mut imports = BTreeSet::new();
18    let alias_name = format_ident!("{}", domain.name.to_upper_camel_case());
19
20    let doc = format!(
21        "Domain: {}.{} (base: {})",
22        domain.schema_name, domain.name, domain.base_type
23    );
24
25    // Create a fake ColumnInfo to reuse the type mapper for the base type
26    let fake_col = crate::introspect::ColumnInfo {
27        name: String::new(),
28        data_type: domain.base_type.clone(),
29        udt_name: domain.base_type.clone(),
30        is_nullable: false,
31        ordinal_position: 0,
32        schema_name: domain.schema_name.clone(),
33    };
34
35    let rust_type = typemap::map_column(&fake_col, db_kind, schema_info, type_overrides);
36    if let Some(imp) = &rust_type.needs_import {
37        imports.insert(imp.clone());
38    }
39
40    let type_tokens: TokenStream = rust_type.path.parse().unwrap_or_else(|_| {
41        let fallback = format_ident!("String");
42        quote! { #fallback }
43    });
44
45    let tokens = quote! {
46        #[doc = #doc]
47        pub type #alias_name = #type_tokens;
48    };
49
50    (tokens, imports)
51}
52
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::parse_and_format;

    /// Builds a domain named `name` over `base` in the `public` schema.
    fn make_domain(name: &str, base: &str) -> DomainInfo {
        DomainInfo {
            schema_name: String::from("public"),
            name: name.to_owned(),
            base_type: base.to_owned(),
        }
    }

    /// Generates the alias for `domain` on Postgres with an empty schema and the given
    /// type overrides, returning the formatted code and the collected imports.
    fn render_with_overrides(
        domain: &DomainInfo,
        overrides: &HashMap<String, String>,
    ) -> (String, BTreeSet<String>) {
        let empty_schema = SchemaInfo::default();
        let (tokens, imports) =
            generate_domain(domain, DatabaseKind::Postgres, &empty_schema, overrides);
        (parse_and_format(&tokens), imports)
    }

    /// Same as [`render_with_overrides`] but without any type overrides.
    fn render(domain: &DomainInfo) -> (String, BTreeSet<String>) {
        render_with_overrides(domain, &HashMap::new())
    }

    #[test]
    fn test_domain_text() {
        let (code, _) = render(&make_domain("email", "text"));
        assert!(code.contains("pub type Email = String"));
    }

    #[test]
    fn test_domain_int4() {
        let (code, _) = render(&make_domain("positive_int", "int4"));
        assert!(code.contains("pub type PositiveInt = i32"));
    }

    #[test]
    fn test_domain_uuid() {
        let (code, imports) = render(&make_domain("my_uuid", "uuid"));
        assert!(code.contains("pub type MyUuid = Uuid"));
        assert!(imports.iter().any(|import| import.contains("uuid::Uuid")));
    }

    #[test]
    fn test_doc_comment() {
        let (code, _) = render(&make_domain("email", "text"));
        assert!(code.contains("Domain: public.email (base: text)"));
    }

    #[test]
    fn test_import_when_needed() {
        let (_, imports) = render(&make_domain("my_uuid", "uuid"));
        assert!(!imports.is_empty());
    }

    #[test]
    fn test_no_import_simple_type() {
        let (_, imports) = render(&make_domain("email", "text"));
        assert!(imports.is_empty());
    }

    #[test]
    fn test_pascal_case_name() {
        let (code, _) = render(&make_domain("email_address", "text"));
        assert!(code.contains("pub type EmailAddress"));
    }

    #[test]
    fn test_type_override() {
        // A single override keyed by the base type must replace the default mapping.
        let overrides: HashMap<String, String> =
            HashMap::from([("jsonb".to_string(), "MyJson".to_string())]);
        let (code, _) = render_with_overrides(&make_domain("json_data", "jsonb"), &overrides);
        assert!(code.contains("pub type JsonData = MyJson"));
    }

    #[test]
    fn test_domain_jsonb() {
        let (code, imports) = render(&make_domain("data", "jsonb"));
        assert!(code.contains("Value"));
        assert!(imports.iter().any(|import| import.contains("serde_json")));
    }

    #[test]
    fn test_domain_timestamptz() {
        let (_, imports) = render(&make_domain("created", "timestamptz"));
        assert!(imports.iter().any(|import| import.contains("chrono")));
    }
}